mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-25 19:03:43 +08:00 
			
		
		
		
	Compare commits
	
		
			22 Commits
		
	
	
		
			v0.6.6-alp
			...
			v0.6.6-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | e64e7707a0 | ||
|  | ea210b6ed7 | ||
|  | 9026ec7510 | ||
|  | c317872097 | ||
|  | da0842272c | ||
|  | 0a650b85b4 | ||
|  | 24f026d18e | ||
|  | cb33e8aad5 | ||
|  | 779b747e9e | ||
|  | 3d149fedf4 | ||
|  | 83517f687c | ||
|  | e30ebda0fe | ||
|  | d87c55f542 | ||
|  | e5b3e37c46 | ||
|  | 8de489cf06 | ||
|  | d14e4aa01b | ||
|  | 541182102e | ||
|  | b2679cca65 | ||
|  | 8572fac7a2 | ||
|  | a2a00dfbc3 | ||
|  | 129282f4a9 | ||
|  | a873cbd392 | 
| @@ -82,6 +82,10 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用  | ||||
|    + [x] [Ollama](https://github.com/ollama/ollama) | ||||
|    + [x] [零一万物](https://platform.lingyiwanwu.com/) | ||||
|    + [x] [阶跃星辰](https://platform.stepfun.com/) | ||||
|    + [x] [Coze](https://www.coze.com/) | ||||
|    + [x] [Cohere](https://cohere.com/) | ||||
|    + [x] [DeepSeek](https://www.deepseek.com/) | ||||
|    + [x] [Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/) | ||||
| 2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。 | ||||
| 3. 支持通过**负载均衡**的方式访问多个渠道。 | ||||
| 4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。 | ||||
|   | ||||
| @@ -1,12 +0,0 @@ | ||||
| package config | ||||
|  | ||||
| const ( | ||||
| 	KeyPrefix = "cfg_" | ||||
|  | ||||
| 	KeyAPIVersion = KeyPrefix + "api_version" | ||||
| 	KeyLibraryID  = KeyPrefix + "library_id" | ||||
| 	KeyPlugin     = KeyPrefix + "plugin" | ||||
| 	KeySK         = KeyPrefix + "sk" | ||||
| 	KeyAK         = KeyPrefix + "ak" | ||||
| 	KeyRegion     = KeyPrefix + "region" | ||||
| ) | ||||
| @@ -1,7 +1,22 @@ | ||||
| package ctxkey | ||||
|  | ||||
| var ( | ||||
| 	RequestModel     = "request_model" | ||||
| 	ConvertedRequest = "converted_request" | ||||
| 	OriginalModel    = "original_model" | ||||
| const ( | ||||
| 	Config            = "config" | ||||
| 	Id                = "id" | ||||
| 	Username          = "username" | ||||
| 	Role              = "role" | ||||
| 	Status            = "status" | ||||
| 	Channel           = "channel" | ||||
| 	ChannelId         = "channel_id" | ||||
| 	SpecificChannelId = "specific_channel_id" | ||||
| 	RequestModel      = "request_model" | ||||
| 	ConvertedRequest  = "converted_request" | ||||
| 	OriginalModel     = "original_model" | ||||
| 	Group             = "group" | ||||
| 	ModelMapping      = "model_mapping" | ||||
| 	ChannelName       = "channel_name" | ||||
| 	TokenId           = "token_id" | ||||
| 	TokenName         = "token_name" | ||||
| 	BaseURL           = "base_url" | ||||
| 	AvailableModels   = "available_models" | ||||
| ) | ||||
|   | ||||
| @@ -2,6 +2,7 @@ package helper | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"html/template" | ||||
| 	"log" | ||||
| @@ -105,6 +106,11 @@ func GenRequestID() string { | ||||
| 	return GetTimeString() + random.GetRandomNumberString(8) | ||||
| } | ||||
|  | ||||
| func GetResponseID(c *gin.Context) string { | ||||
| 	logID := c.GetString(RequestIdKey) | ||||
| 	return fmt.Sprintf("chatcmpl-%s", logID) | ||||
| } | ||||
|  | ||||
| func Max(a int, b int) int { | ||||
| 	if a >= b { | ||||
| 		return a | ||||
|   | ||||
							
								
								
									
										5
									
								
								common/helper/key.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								common/helper/key.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| package helper | ||||
|  | ||||
| const ( | ||||
| 	RequestIdKey = "X-Oneapi-Request-Id" | ||||
| ) | ||||
| @@ -16,7 +16,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| // Regex to match data URL pattern | ||||
| var	dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) | ||||
| var dataURLPattern = regexp.MustCompile(`data:image/([^;]+);base64,(.*)`) | ||||
|  | ||||
| func IsImageUrl(url string) (bool, error) { | ||||
| 	resp, err := http.Head(url) | ||||
|   | ||||
| @@ -1,7 +1,3 @@ | ||||
| package logger | ||||
|  | ||||
| const ( | ||||
| 	RequestIdKey = "X-Oneapi-Request-Id" | ||||
| ) | ||||
|  | ||||
| var LogDir string | ||||
|   | ||||
| @@ -3,15 +3,16 @@ package logger | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"io" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -21,28 +22,20 @@ const ( | ||||
| 	loggerError = "ERR" | ||||
| ) | ||||
|  | ||||
| var setupLogLock sync.Mutex | ||||
| var setupLogWorking bool | ||||
| var setupLogOnce sync.Once | ||||
|  | ||||
| func SetupLogger() { | ||||
| 	if LogDir != "" { | ||||
| 		ok := setupLogLock.TryLock() | ||||
| 		if !ok { | ||||
| 			log.Println("setup log is already working") | ||||
| 			return | ||||
| 	setupLogOnce.Do(func() { | ||||
| 		if LogDir != "" { | ||||
| 			logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) | ||||
| 			fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | ||||
| 			if err != nil { | ||||
| 				log.Fatal("failed to open log file") | ||||
| 			} | ||||
| 			gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) | ||||
| 			gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) | ||||
| 		} | ||||
| 		defer func() { | ||||
| 			setupLogLock.Unlock() | ||||
| 			setupLogWorking = false | ||||
| 		}() | ||||
| 		logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) | ||||
| 		fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) | ||||
| 		if err != nil { | ||||
| 			log.Fatal("failed to open log file") | ||||
| 		} | ||||
| 		gin.DefaultWriter = io.MultiWriter(os.Stdout, fd) | ||||
| 		gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd) | ||||
| 	} | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func SysLog(s string) { | ||||
| @@ -94,18 +87,13 @@ func logHelper(ctx context.Context, level string, msg string) { | ||||
| 	if level == loggerINFO { | ||||
| 		writer = gin.DefaultWriter | ||||
| 	} | ||||
| 	id := ctx.Value(RequestIdKey) | ||||
| 	id := ctx.Value(helper.RequestIdKey) | ||||
| 	if id == nil { | ||||
| 		id = helper.GenRequestID() | ||||
| 	} | ||||
| 	now := time.Now() | ||||
| 	_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg) | ||||
| 	if !setupLogWorking { | ||||
| 		setupLogWorking = true | ||||
| 		go func() { | ||||
| 			SetupLogger() | ||||
| 		}() | ||||
| 	} | ||||
| 	SetupLogger() | ||||
| } | ||||
|  | ||||
| func FatalLog(v ...any) { | ||||
|   | ||||
| @@ -6,6 +6,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/controller" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| @@ -136,7 +137,7 @@ func WeChatBind(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	id := c.GetInt("id") | ||||
| 	id := c.GetInt(ctxkey.Id) | ||||
| 	user := model.User{ | ||||
| 		Id: id, | ||||
| 	} | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package controller | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
| @@ -14,13 +15,13 @@ func GetSubscription(c *gin.Context) { | ||||
| 	var token *model.Token | ||||
| 	var expiredTime int64 | ||||
| 	if config.DisplayTokenStatEnabled { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		tokenId := c.GetInt(ctxkey.TokenId) | ||||
| 		token, err = model.GetTokenById(tokenId) | ||||
| 		expiredTime = token.ExpiredTime | ||||
| 		remainQuota = token.RemainQuota | ||||
| 		usedQuota = token.UsedQuota | ||||
| 	} else { | ||||
| 		userId := c.GetInt("id") | ||||
| 		userId := c.GetInt(ctxkey.Id) | ||||
| 		remainQuota, err = model.GetUserQuota(userId) | ||||
| 		if err != nil { | ||||
| 			usedQuota, err = model.GetUserUsedQuota(userId) | ||||
| @@ -64,11 +65,11 @@ func GetUsage(c *gin.Context) { | ||||
| 	var err error | ||||
| 	var token *model.Token | ||||
| 	if config.DisplayTokenStatEnabled { | ||||
| 		tokenId := c.GetInt("token_id") | ||||
| 		tokenId := c.GetInt(ctxkey.TokenId) | ||||
| 		token, err = model.GetTokenById(tokenId) | ||||
| 		quota = token.UsedQuota | ||||
| 	} else { | ||||
| 		userId := c.GetInt("id") | ||||
| 		userId := c.GetInt(ctxkey.Id) | ||||
| 		quota, err = model.GetUserUsedQuota(userId) | ||||
| 	} | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -6,6 +6,7 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/message" | ||||
| 	"github.com/songquanpeng/one-api/middleware" | ||||
| @@ -54,8 +55,10 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error | ||||
| 	} | ||||
| 	c.Request.Header.Set("Authorization", "Bearer "+channel.Key) | ||||
| 	c.Request.Header.Set("Content-Type", "application/json") | ||||
| 	c.Set("channel", channel.Type) | ||||
| 	c.Set("base_url", channel.GetBaseURL()) | ||||
| 	c.Set(ctxkey.Channel, channel.Type) | ||||
| 	c.Set(ctxkey.BaseURL, channel.GetBaseURL()) | ||||
| 	cfg, _ := channel.LoadConfig() | ||||
| 	c.Set(ctxkey.Config, cfg) | ||||
| 	middleware.SetupContextForSelectedChannel(c, channel, "") | ||||
| 	meta := meta.GetByContext(c) | ||||
| 	apiType := channeltype.ToAPIType(channel.Type) | ||||
| @@ -64,8 +67,12 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error | ||||
| 		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil | ||||
| 	} | ||||
| 	adaptor.Init(meta) | ||||
| 	modelName := adaptor.GetModelList()[0] | ||||
| 	if !strings.Contains(channel.Models, modelName) { | ||||
| 	var modelName string | ||||
| 	modelList := adaptor.GetModelList() | ||||
| 	if len(modelList) != 0 { | ||||
| 		modelName = modelList[0] | ||||
| 	} | ||||
| 	if modelName == "" || !strings.Contains(channel.Models, modelName) { | ||||
| 		modelNames := strings.Split(channel.Models, ",") | ||||
| 		if len(modelNames) > 0 { | ||||
| 			modelName = modelNames[0] | ||||
| @@ -82,6 +89,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error | ||||
| 	if err != nil { | ||||
| 		return err, nil | ||||
| 	} | ||||
| 	logger.SysLog(string(jsonData)) | ||||
| 	requestBody := bytes.NewBuffer(jsonData) | ||||
| 	c.Request.Body = io.NopCloser(requestBody) | ||||
| 	resp, err := adaptor.DoRequest(c, meta, requestBody) | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package controller | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| @@ -41,7 +42,7 @@ func GetUserLogs(c *gin.Context) { | ||||
| 	if p < 0 { | ||||
| 		p = 0 | ||||
| 	} | ||||
| 	userId := c.GetInt("id") | ||||
| 	userId := c.GetInt(ctxkey.Id) | ||||
| 	logType, _ := strconv.Atoi(c.Query("type")) | ||||
| 	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) | ||||
| 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | ||||
| @@ -83,7 +84,7 @@ func SearchAllLogs(c *gin.Context) { | ||||
|  | ||||
| func SearchUserLogs(c *gin.Context) { | ||||
| 	keyword := c.Query("keyword") | ||||
| 	userId := c.GetInt("id") | ||||
| 	userId := c.GetInt(ctxkey.Id) | ||||
| 	logs, err := model.SearchUserLogs(userId, keyword) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| @@ -122,7 +123,7 @@ func GetLogsStat(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func GetLogsSelfStat(c *gin.Context) { | ||||
| 	username := c.GetString("username") | ||||
| 	username := c.GetString(ctxkey.Username) | ||||
| 	logType, _ := strconv.Atoi(c.Query("type")) | ||||
| 	startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64) | ||||
| 	endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package controller | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	relay "github.com/songquanpeng/one-api/relay" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| @@ -131,10 +132,10 @@ func ListAllModels(c *gin.Context) { | ||||
| func ListModels(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	var availableModels []string | ||||
| 	if c.GetString("available_models") != "" { | ||||
| 		availableModels = strings.Split(c.GetString("available_models"), ",") | ||||
| 	if c.GetString(ctxkey.AvailableModels) != "" { | ||||
| 		availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",") | ||||
| 	} else { | ||||
| 		userId := c.GetInt("id") | ||||
| 		userId := c.GetInt(ctxkey.Id) | ||||
| 		userGroup, _ := model.CacheGetUserGroup(userId) | ||||
| 		availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) | ||||
| 	} | ||||
| @@ -186,7 +187,7 @@ func RetrieveModel(c *gin.Context) { | ||||
|  | ||||
| func GetUserAvailableModels(c *gin.Context) { | ||||
| 	ctx := c.Request.Context() | ||||
| 	id := c.GetInt("id") | ||||
| 	id := c.GetInt(ctxkey.Id) | ||||
| 	userGroup, err := model.CacheGetUserGroup(id) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package controller | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| @@ -109,7 +110,7 @@ func AddRedemption(c *gin.Context) { | ||||
| 	for i := 0; i < redemption.Count; i++ { | ||||
| 		key := random.GetUUID() | ||||
| 		cleanRedemption := model.Redemption{ | ||||
| 			UserId:      c.GetInt("id"), | ||||
| 			UserId:      c.GetInt(ctxkey.Id), | ||||
| 			Name:        redemption.Name, | ||||
| 			Key:         key, | ||||
| 			CreatedTime: helper.GetTimestamp(), | ||||
|   | ||||
| @@ -46,18 +46,18 @@ func Relay(c *gin.Context) { | ||||
| 		requestBody, _ := common.GetRequestBody(c) | ||||
| 		logger.Debugf(ctx, "request body: %s", string(requestBody)) | ||||
| 	} | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	channelId := c.GetInt(ctxkey.ChannelId) | ||||
| 	bizErr := relayHelper(c, relayMode) | ||||
| 	if bizErr == nil { | ||||
| 		monitor.Emit(channelId, true) | ||||
| 		return | ||||
| 	} | ||||
| 	lastFailedChannelId := channelId | ||||
| 	channelName := c.GetString("channel_name") | ||||
| 	group := c.GetString("group") | ||||
| 	channelName := c.GetString(ctxkey.ChannelName) | ||||
| 	group := c.GetString(ctxkey.Group) | ||||
| 	originalModel := c.GetString(ctxkey.OriginalModel) | ||||
| 	go processChannelRelayError(ctx, channelId, channelName, bizErr) | ||||
| 	requestId := c.GetString(logger.RequestIdKey) | ||||
| 	requestId := c.GetString(helper.RequestIdKey) | ||||
| 	retryTimes := config.RetryTimes | ||||
| 	if !shouldRetry(c, bizErr.StatusCode) { | ||||
| 		logger.Errorf(ctx, "relay error happen, status code is %d, won't retry in this case", bizErr.StatusCode) | ||||
| @@ -80,9 +80,9 @@ func Relay(c *gin.Context) { | ||||
| 		if bizErr == nil { | ||||
| 			return | ||||
| 		} | ||||
| 		channelId := c.GetInt("channel_id") | ||||
| 		channelId := c.GetInt(ctxkey.ChannelId) | ||||
| 		lastFailedChannelId = channelId | ||||
| 		channelName := c.GetString("channel_name") | ||||
| 		channelName := c.GetString(ctxkey.ChannelName) | ||||
| 		go processChannelRelayError(ctx, channelId, channelName, bizErr) | ||||
| 	} | ||||
| 	if bizErr != nil { | ||||
| @@ -97,7 +97,7 @@ func Relay(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func shouldRetry(c *gin.Context, statusCode int) bool { | ||||
| 	if _, ok := c.Get("specific_channel_id"); ok { | ||||
| 	if _, ok := c.Get(ctxkey.SpecificChannelId); ok { | ||||
| 		return false | ||||
| 	} | ||||
| 	if statusCode == http.StatusTooManyRequests { | ||||
|   | ||||
| @@ -4,6 +4,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/network" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| @@ -13,7 +14,7 @@ import ( | ||||
| ) | ||||
|  | ||||
| func GetAllTokens(c *gin.Context) { | ||||
| 	userId := c.GetInt("id") | ||||
| 	userId := c.GetInt(ctxkey.Id) | ||||
| 	p, _ := strconv.Atoi(c.Query("p")) | ||||
| 	if p < 0 { | ||||
| 		p = 0 | ||||
| @@ -38,7 +39,7 @@ func GetAllTokens(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func SearchTokens(c *gin.Context) { | ||||
| 	userId := c.GetInt("id") | ||||
| 	userId := c.GetInt(ctxkey.Id) | ||||
| 	keyword := c.Query("keyword") | ||||
| 	tokens, err := model.SearchUserTokens(userId, keyword) | ||||
| 	if err != nil { | ||||
| @@ -58,7 +59,7 @@ func SearchTokens(c *gin.Context) { | ||||
|  | ||||
| func GetToken(c *gin.Context) { | ||||
| 	id, err := strconv.Atoi(c.Param("id")) | ||||
| 	userId := c.GetInt("id") | ||||
| 	userId := c.GetInt(ctxkey.Id) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -83,8 +84,8 @@ func GetToken(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func GetTokenStatus(c *gin.Context) { | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| 	userId := c.GetInt("id") | ||||
| 	tokenId := c.GetInt(ctxkey.TokenId) | ||||
| 	userId := c.GetInt(ctxkey.Id) | ||||
| 	token, err := model.GetTokenByIds(tokenId, userId) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| @@ -139,7 +140,7 @@ func AddToken(c *gin.Context) { | ||||
| 	} | ||||
|  | ||||
| 	cleanToken := model.Token{ | ||||
| 		UserId:         c.GetInt("id"), | ||||
| 		UserId:         c.GetInt(ctxkey.Id), | ||||
| 		Name:           token.Name, | ||||
| 		Key:            random.GenerateKey(), | ||||
| 		CreatedTime:    helper.GetTimestamp(), | ||||
| @@ -168,7 +169,7 @@ func AddToken(c *gin.Context) { | ||||
|  | ||||
| func DeleteToken(c *gin.Context) { | ||||
| 	id, _ := strconv.Atoi(c.Param("id")) | ||||
| 	userId := c.GetInt("id") | ||||
| 	userId := c.GetInt(ctxkey.Id) | ||||
| 	err := model.DeleteTokenById(id, userId) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| @@ -185,7 +186,7 @@ func DeleteToken(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func UpdateToken(c *gin.Context) { | ||||
| 	userId := c.GetInt("id") | ||||
| 	userId := c.GetInt(ctxkey.Id) | ||||
| 	statusOnly := c.Query("status_only") | ||||
| 	token := model.Token{} | ||||
| 	err := c.ShouldBindJSON(&token) | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| @@ -238,7 +239,7 @@ func GetUser(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	myRole := c.GetInt("role") | ||||
| 	myRole := c.GetInt(ctxkey.Role) | ||||
| 	if myRole <= user.Role && myRole != model.RoleRootUser { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -255,7 +256,7 @@ func GetUser(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func GetUserDashboard(c *gin.Context) { | ||||
| 	id := c.GetInt("id") | ||||
| 	id := c.GetInt(ctxkey.Id) | ||||
| 	now := time.Now() | ||||
| 	startOfDay := now.Truncate(24*time.Hour).AddDate(0, 0, -6).Unix() | ||||
| 	endOfDay := now.Truncate(24 * time.Hour).Add(24*time.Hour - time.Second).Unix() | ||||
| @@ -278,7 +279,7 @@ func GetUserDashboard(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func GenerateAccessToken(c *gin.Context) { | ||||
| 	id := c.GetInt("id") | ||||
| 	id := c.GetInt(ctxkey.Id) | ||||
| 	user, err := model.GetUserById(id, true) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| @@ -314,7 +315,7 @@ func GenerateAccessToken(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func GetAffCode(c *gin.Context) { | ||||
| 	id := c.GetInt("id") | ||||
| 	id := c.GetInt(ctxkey.Id) | ||||
| 	user, err := model.GetUserById(id, true) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| @@ -342,7 +343,7 @@ func GetAffCode(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func GetSelf(c *gin.Context) { | ||||
| 	id := c.GetInt("id") | ||||
| 	id := c.GetInt(ctxkey.Id) | ||||
| 	user, err := model.GetUserById(id, false) | ||||
| 	if err != nil { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| @@ -387,7 +388,7 @@ func UpdateUser(c *gin.Context) { | ||||
| 		}) | ||||
| 		return | ||||
| 	} | ||||
| 	myRole := c.GetInt("role") | ||||
| 	myRole := c.GetInt(ctxkey.Role) | ||||
| 	if myRole <= originUser.Role && myRole != model.RoleRootUser { | ||||
| 		c.JSON(http.StatusOK, gin.H{ | ||||
| 			"success": false, | ||||
| @@ -445,7 +446,7 @@ func UpdateSelf(c *gin.Context) { | ||||
| 	} | ||||
|  | ||||
| 	cleanUser := model.User{ | ||||
| 		Id:          c.GetInt("id"), | ||||
| 		Id:          c.GetInt(ctxkey.Id), | ||||
| 		Username:    user.Username, | ||||
| 		Password:    user.Password, | ||||
| 		DisplayName: user.DisplayName, | ||||
|   | ||||
							
								
								
									
										2
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								main.go
									
									
									
									
									
								
							| @@ -71,7 +71,7 @@ func main() { | ||||
| 	} | ||||
| 	if config.MemoryCacheEnabled { | ||||
| 		logger.SysLog("memory cache enabled") | ||||
| 		logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency)) | ||||
| 		logger.SysLog(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency)) | ||||
| 		model.InitChannelCache() | ||||
| 	} | ||||
| 	if config.MemoryCacheEnabled { | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"github.com/gin-contrib/sessions" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/blacklist" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/network" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"net/http" | ||||
| @@ -120,20 +121,20 @@ func TokenAuth() func(c *gin.Context) { | ||||
| 			abortWithMessage(c, http.StatusBadRequest, err.Error()) | ||||
| 			return | ||||
| 		} | ||||
| 		c.Set("request_model", requestModel) | ||||
| 		c.Set(ctxkey.RequestModel, requestModel) | ||||
| 		if token.Models != nil && *token.Models != "" { | ||||
| 			c.Set("available_models", *token.Models) | ||||
| 			c.Set(ctxkey.AvailableModels, *token.Models) | ||||
| 			if requestModel != "" && !isModelInList(requestModel, *token.Models) { | ||||
| 				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 		c.Set("id", token.UserId) | ||||
| 		c.Set("token_id", token.Id) | ||||
| 		c.Set("token_name", token.Name) | ||||
| 		c.Set(ctxkey.Id, token.UserId) | ||||
| 		c.Set(ctxkey.TokenId, token.Id) | ||||
| 		c.Set(ctxkey.TokenName, token.Name) | ||||
| 		if len(parts) > 1 { | ||||
| 			if model.IsAdmin(token.UserId) { | ||||
| 				c.Set("specific_channel_id", parts[1]) | ||||
| 				c.Set(ctxkey.SpecificChannelId, parts[1]) | ||||
| 			} else { | ||||
| 				abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") | ||||
| 				return | ||||
|   | ||||
| @@ -3,7 +3,6 @@ package middleware | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| @@ -18,12 +17,12 @@ type ModelRequest struct { | ||||
|  | ||||
| func Distribute() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		userId := c.GetInt("id") | ||||
| 		userId := c.GetInt(ctxkey.Id) | ||||
| 		userGroup, _ := model.CacheGetUserGroup(userId) | ||||
| 		c.Set("group", userGroup) | ||||
| 		c.Set(ctxkey.Group, userGroup) | ||||
| 		var requestModel string | ||||
| 		var channel *model.Channel | ||||
| 		channelId, ok := c.Get("specific_channel_id") | ||||
| 		channelId, ok := c.Get(ctxkey.SpecificChannelId) | ||||
| 		if ok { | ||||
| 			id, err := strconv.Atoi(channelId.(string)) | ||||
| 			if err != nil { | ||||
| @@ -40,7 +39,7 @@ func Distribute() func(c *gin.Context) { | ||||
| 				return | ||||
| 			} | ||||
| 		} else { | ||||
| 			requestModel = c.GetString("request_model") | ||||
| 			requestModel = c.GetString(ctxkey.RequestModel) | ||||
| 			var err error | ||||
| 			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false) | ||||
| 			if err != nil { | ||||
| @@ -59,28 +58,36 @@ func Distribute() func(c *gin.Context) { | ||||
| } | ||||
|  | ||||
| func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { | ||||
| 	c.Set("channel", channel.Type) | ||||
| 	c.Set("channel_id", channel.Id) | ||||
| 	c.Set("channel_name", channel.Name) | ||||
| 	c.Set("model_mapping", channel.GetModelMapping()) | ||||
| 	c.Set(ctxkey.Channel, channel.Type) | ||||
| 	c.Set(ctxkey.ChannelId, channel.Id) | ||||
| 	c.Set(ctxkey.ChannelName, channel.Name) | ||||
| 	c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) | ||||
| 	c.Set(ctxkey.OriginalModel, modelName) // for retry | ||||
| 	c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||
| 	c.Set("base_url", channel.GetBaseURL()) | ||||
| 	c.Set(ctxkey.BaseURL, channel.GetBaseURL()) | ||||
| 	cfg, _ := channel.LoadConfig() | ||||
| 	// this is for backward compatibility | ||||
| 	switch channel.Type { | ||||
| 	case channeltype.Azure: | ||||
| 		c.Set(config.KeyAPIVersion, channel.Other) | ||||
| 		if cfg.APIVersion == "" { | ||||
| 			cfg.APIVersion = channel.Other | ||||
| 		} | ||||
| 	case channeltype.Xunfei: | ||||
| 		c.Set(config.KeyAPIVersion, channel.Other) | ||||
| 		if cfg.APIVersion == "" { | ||||
| 			cfg.APIVersion = channel.Other | ||||
| 		} | ||||
| 	case channeltype.Gemini: | ||||
| 		c.Set(config.KeyAPIVersion, channel.Other) | ||||
| 		if cfg.APIVersion == "" { | ||||
| 			cfg.APIVersion = channel.Other | ||||
| 		} | ||||
| 	case channeltype.AIProxyLibrary: | ||||
| 		c.Set(config.KeyLibraryID, channel.Other) | ||||
| 		if cfg.LibraryID == "" { | ||||
| 			cfg.LibraryID = channel.Other | ||||
| 		} | ||||
| 	case channeltype.Ali: | ||||
| 		c.Set(config.KeyPlugin, channel.Other) | ||||
| 	} | ||||
| 	cfg, _ := channel.LoadConfig() | ||||
| 	for k, v := range cfg { | ||||
| 		c.Set(config.KeyPrefix+k, v) | ||||
| 		if cfg.Plugin == "" { | ||||
| 			cfg.Plugin = channel.Other | ||||
| 		} | ||||
| 	} | ||||
| 	c.Set(ctxkey.Config, cfg) | ||||
| } | ||||
|   | ||||
| @@ -3,14 +3,14 @@ package middleware | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| ) | ||||
|  | ||||
| func SetUpLogger(server *gin.Engine) { | ||||
| 	server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { | ||||
| 		var requestID string | ||||
| 		if param.Keys != nil { | ||||
| 			requestID = param.Keys[logger.RequestIdKey].(string) | ||||
| 			requestID = param.Keys[helper.RequestIdKey].(string) | ||||
| 		} | ||||
| 		return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", | ||||
| 			param.TimeStamp.Format("2006/01/02 - 15:04:05"), | ||||
|   | ||||
| @@ -4,16 +4,15 @@ import ( | ||||
| 	"context" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| ) | ||||
|  | ||||
| func RequestId() func(c *gin.Context) { | ||||
| 	return func(c *gin.Context) { | ||||
| 		id := helper.GenRequestID() | ||||
| 		c.Set(logger.RequestIdKey, id) | ||||
| 		ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) | ||||
| 		c.Set(helper.RequestIdKey, id) | ||||
| 		ctx := context.WithValue(c.Request.Context(), helper.RequestIdKey, id) | ||||
| 		c.Request = c.Request.WithContext(ctx) | ||||
| 		c.Header(logger.RequestIdKey, id) | ||||
| 		c.Header(helper.RequestIdKey, id) | ||||
| 		c.Next() | ||||
| 	} | ||||
| } | ||||
|   | ||||
| @@ -12,7 +12,7 @@ import ( | ||||
| func abortWithMessage(c *gin.Context, statusCode int, message string) { | ||||
| 	c.JSON(statusCode, gin.H{ | ||||
| 		"error": gin.H{ | ||||
| 			"message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)), | ||||
| 			"message": helper.MessageWithRequestId(message, c.GetString(helper.RequestIdKey)), | ||||
| 			"type":    "one_api_error", | ||||
| 		}, | ||||
| 	}) | ||||
|   | ||||
| @@ -38,6 +38,16 @@ type Channel struct { | ||||
| 	Config             string  `json:"config"` | ||||
| } | ||||
|  | ||||
| type ChannelConfig struct { | ||||
| 	Region     string `json:"region,omitempty"` | ||||
| 	SK         string `json:"sk,omitempty"` | ||||
| 	AK         string `json:"ak,omitempty"` | ||||
| 	UserID     string `json:"user_id,omitempty"` | ||||
| 	APIVersion string `json:"api_version,omitempty"` | ||||
| 	LibraryID  string `json:"library_id,omitempty"` | ||||
| 	Plugin     string `json:"plugin,omitempty"` | ||||
| } | ||||
|  | ||||
| func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { | ||||
| 	var channels []*Channel | ||||
| 	var err error | ||||
| @@ -161,14 +171,14 @@ func (channel *Channel) Delete() error { | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| func (channel *Channel) LoadConfig() (map[string]string, error) { | ||||
| func (channel *Channel) LoadConfig() (ChannelConfig, error) { | ||||
| 	var cfg ChannelConfig | ||||
| 	if channel.Config == "" { | ||||
| 		return nil, nil | ||||
| 		return cfg, nil | ||||
| 	} | ||||
| 	cfg := make(map[string]string) | ||||
| 	err := json.Unmarshal([]byte(channel.Config), &cfg) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 		return cfg, err | ||||
| 	} | ||||
| 	return cfg, nil | ||||
| } | ||||
|   | ||||
| @@ -7,6 +7,9 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/anthropic" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/aws" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/baidu" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/cloudflare" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/cohere" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/coze" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/gemini" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/ollama" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| @@ -43,6 +46,12 @@ func GetAdaptor(apiType int) adaptor.Adaptor { | ||||
| 		return &zhipu.Adaptor{} | ||||
| 	case apitype.Ollama: | ||||
| 		return &ollama.Adaptor{} | ||||
| 	case apitype.Coze: | ||||
| 		return &coze.Adaptor{} | ||||
| 	case apitype.Cohere: | ||||
| 		return &cohere.Adaptor{} | ||||
| 	case apitype.Cloudflare: | ||||
| 		return &cloudflare.Adaptor{} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -4,7 +4,6 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| @@ -13,10 +12,11 @@ import ( | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| 	meta *meta.Meta | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
|  | ||||
| 	a.meta = meta | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| @@ -34,7 +34,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	aiProxyLibraryRequest := ConvertRequest(*request) | ||||
| 	aiProxyLibraryRequest.LibraryId = c.GetString(config.KeyLibraryID) | ||||
| 	aiProxyLibraryRequest.LibraryId = a.meta.Config.LibraryID | ||||
| 	return aiProxyLibraryRequest, nil | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -4,7 +4,6 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| @@ -16,10 +15,11 @@ import ( | ||||
| // https://help.aliyun.com/zh/dashscope/developer-reference/api-details | ||||
|  | ||||
| type Adaptor struct { | ||||
| 	meta *meta.Meta | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
|  | ||||
| 	a.meta = meta | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| @@ -47,8 +47,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me | ||||
| 	if meta.Mode == relaymode.ImagesGenerations { | ||||
| 		req.Header.Set("X-DashScope-Async", "enable") | ||||
| 	} | ||||
| 	if c.GetString(config.KeyPlugin) != "" { | ||||
| 		req.Header.Set("X-DashScope-Plugin", c.GetString(config.KeyPlugin)) | ||||
| 	if a.meta.Config.Plugin != "" { | ||||
| 		req.Header.Set("X-DashScope-Plugin", a.meta.Config.Plugin) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|   | ||||
| @@ -1,6 +1,9 @@ | ||||
| package aws | ||||
|  | ||||
| import ( | ||||
| 	"github.com/aws/aws-sdk-go-v2/aws" | ||||
| 	"github.com/aws/aws-sdk-go-v2/credentials" | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| @@ -16,10 +19,16 @@ import ( | ||||
| var _ adaptor.Adaptor = new(Adaptor) | ||||
|  | ||||
| type Adaptor struct { | ||||
| 	meta      *meta.Meta | ||||
| 	awsClient *bedrockruntime.Client | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
|  | ||||
| 	a.meta = meta | ||||
| 	a.awsClient = bedrockruntime.New(bedrockruntime.Options{ | ||||
| 		Region:      meta.Config.Region, | ||||
| 		Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(meta.Config.AK, meta.Config.SK, "")), | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| @@ -54,9 +63,9 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Read | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, resp) | ||||
| 		err, usage = StreamHandler(c, a.awsClient) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||
| 		err, usage = Handler(c, a.awsClient, meta.ActualModelName) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
| @@ -65,7 +74,6 @@ func (a *Adaptor) GetModelList() (models []string) { | ||||
| 	for n := range awsModelIDMap { | ||||
| 		models = append(models, n) | ||||
| 	} | ||||
|  | ||||
| 	return | ||||
| } | ||||
|  | ||||
|   | ||||
| @@ -5,13 +5,11 @@ import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/aws/aws-sdk-go-v2/aws" | ||||
| 	"github.com/aws/aws-sdk-go-v2/credentials" | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime" | ||||
| 	"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| @@ -24,18 +22,6 @@ import ( | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func newAwsClient(c *gin.Context) (*bedrockruntime.Client, error) { | ||||
| 	ak := c.GetString(config.KeyAK) | ||||
| 	sk := c.GetString(config.KeySK) | ||||
| 	region := c.GetString(config.KeyRegion) | ||||
| 	client := bedrockruntime.New(bedrockruntime.Options{ | ||||
| 		Region:      region, | ||||
| 		Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")), | ||||
| 	}) | ||||
|  | ||||
| 	return client, nil | ||||
| } | ||||
|  | ||||
| func wrapErr(err error) *relaymodel.ErrorWithStatusCode { | ||||
| 	return &relaymodel.ErrorWithStatusCode{ | ||||
| 		StatusCode: http.StatusInternalServerError, | ||||
| @@ -63,12 +49,7 @@ func awsModelID(requestModel string) (string, error) { | ||||
| 	return "", errors.Errorf("model %s not found", requestModel) | ||||
| } | ||||
|  | ||||
| func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { | ||||
| 	awsCli, err := newAwsClient(c) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "newAwsClient")), nil | ||||
| 	} | ||||
|  | ||||
| func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { | ||||
| 	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "awsModelID")), nil | ||||
| @@ -121,13 +102,8 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { | ||||
| func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { | ||||
| 	createdTime := helper.GetTimestamp() | ||||
| 	awsCli, err := newAwsClient(c) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "newAwsClient")), nil | ||||
| 	} | ||||
|  | ||||
| 	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) | ||||
| 	if err != nil { | ||||
| 		return wrapErr(errors.Wrap(err, "awsModelID")), nil | ||||
|   | ||||
| @@ -1,15 +0,0 @@ | ||||
| package azure | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| ) | ||||
|  | ||||
| func GetAPIVersion(c *gin.Context) string { | ||||
| 	query := c.Request.URL.Query() | ||||
| 	apiVersion := query.Get("api-version") | ||||
| 	if apiVersion == "" { | ||||
| 		apiVersion = c.GetString(config.KeyAPIVersion) | ||||
| 	} | ||||
| 	return apiVersion | ||||
| } | ||||
							
								
								
									
										66
									
								
								relay/adaptor/cloudflare/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								relay/adaptor/cloudflare/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | ||||
| package cloudflare | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| 	meta *meta.Meta | ||||
| } | ||||
|  | ||||
| // ConvertImageRequest implements adaptor.Adaptor. | ||||
| func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	return nil, errors.New("not implemented") | ||||
| } | ||||
|  | ||||
| // ConvertImageRequest implements adaptor.Adaptor. | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| 	a.meta = meta | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", meta.BaseURL, meta.Config.UserID, meta.ActualModelName), nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	adaptor.SetupCommonRequestHeader(c, req, meta) | ||||
| 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	return ConvertRequest(*request), nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return adaptor.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetModelList() []string { | ||||
| 	return ModelList | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "cloudflare" | ||||
| } | ||||
							
								
								
									
										36
									
								
								relay/adaptor/cloudflare/constant.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								relay/adaptor/cloudflare/constant.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | ||||
| package cloudflare | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"@cf/meta/llama-2-7b-chat-fp16", | ||||
| 	"@cf/meta/llama-2-7b-chat-int8", | ||||
| 	"@cf/mistral/mistral-7b-instruct-v0.1", | ||||
| 	"@hf/thebloke/deepseek-coder-6.7b-base-awq", | ||||
| 	"@hf/thebloke/deepseek-coder-6.7b-instruct-awq", | ||||
| 	"@cf/deepseek-ai/deepseek-math-7b-base", | ||||
| 	"@cf/deepseek-ai/deepseek-math-7b-instruct", | ||||
| 	"@cf/thebloke/discolm-german-7b-v1-awq", | ||||
| 	"@cf/tiiuae/falcon-7b-instruct", | ||||
| 	"@cf/google/gemma-2b-it-lora", | ||||
| 	"@hf/google/gemma-7b-it", | ||||
| 	"@cf/google/gemma-7b-it-lora", | ||||
| 	"@hf/nousresearch/hermes-2-pro-mistral-7b", | ||||
| 	"@hf/thebloke/llama-2-13b-chat-awq", | ||||
| 	"@cf/meta-llama/llama-2-7b-chat-hf-lora", | ||||
| 	"@cf/meta/llama-3-8b-instruct", | ||||
| 	"@hf/thebloke/llamaguard-7b-awq", | ||||
| 	"@hf/thebloke/mistral-7b-instruct-v0.1-awq", | ||||
| 	"@hf/mistralai/mistral-7b-instruct-v0.2", | ||||
| 	"@cf/mistral/mistral-7b-instruct-v0.2-lora", | ||||
| 	"@hf/thebloke/neural-chat-7b-v3-1-awq", | ||||
| 	"@cf/openchat/openchat-3.5-0106", | ||||
| 	"@hf/thebloke/openhermes-2.5-mistral-7b-awq", | ||||
| 	"@cf/microsoft/phi-2", | ||||
| 	"@cf/qwen/qwen1.5-0.5b-chat", | ||||
| 	"@cf/qwen/qwen1.5-1.8b-chat", | ||||
| 	"@cf/qwen/qwen1.5-14b-chat-awq", | ||||
| 	"@cf/qwen/qwen1.5-7b-chat-awq", | ||||
| 	"@cf/defog/sqlcoder-7b-2", | ||||
| 	"@hf/nexusflow/starling-lm-7b-beta", | ||||
| 	"@cf/tinyllama/tinyllama-1.1b-chat-v1.0", | ||||
| 	"@hf/thebloke/zephyr-7b-beta-awq", | ||||
| } | ||||
							
								
								
									
										152
									
								
								relay/adaptor/cloudflare/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										152
									
								
								relay/adaptor/cloudflare/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,152 @@ | ||||
| package cloudflare | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 	lastMessage := textRequest.Messages[len(textRequest.Messages)-1] | ||||
| 	return &Request{ | ||||
| 		MaxTokens:   textRequest.MaxTokens, | ||||
| 		Prompt:      lastMessage.StringContent(), | ||||
| 		Stream:      textRequest.Stream, | ||||
| 		Temperature: textRequest.Temperature, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ResponseCloudflare2OpenAI(cloudflareResponse *Response) *openai.TextResponse { | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: model.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: cloudflareResponse.Result.Response, | ||||
| 		}, | ||||
| 		FinishReason: "stop", | ||||
| 	} | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func StreamResponseCloudflare2OpenAI(cloudflareResponse *StreamResponse) *openai.ChatCompletionsStreamResponse { | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = cloudflareResponse.Response | ||||
| 	choice.Delta.Role = "assistant" | ||||
| 	openaiResponse := openai.ChatCompletionsStreamResponse{ | ||||
| 		Object:  "chat.completion.chunk", | ||||
| 		Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 	} | ||||
| 	return &openaiResponse | ||||
| } | ||||
|  | ||||
| func StreamHandler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := bytes.IndexByte(data, '\n'); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
|  | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < len("data: ") { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = strings.TrimPrefix(data, "data: ") | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	id := helper.GetResponseID(c) | ||||
| 	responseModel := c.GetString("original_model") | ||||
| 	var responseText string | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			var cloudflareResponse StreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &cloudflareResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response := StreamResponseCloudflare2OpenAI(&cloudflareResponse) | ||||
| 			if response == nil { | ||||
| 				return true | ||||
| 			} | ||||
| 			responseText += cloudflareResponse.Response | ||||
| 			response.Id = id | ||||
| 			response.Model = responseModel | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	_ = resp.Body.Close() | ||||
| 	usage := openai.ResponseText2Usage(responseText, responseModel, promptTokens) | ||||
| 	return nil, usage | ||||
| } | ||||
|  | ||||
| func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var cloudflareResponse Response | ||||
| 	err = json.Unmarshal(responseBody, &cloudflareResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	fullTextResponse := ResponseCloudflare2OpenAI(&cloudflareResponse) | ||||
| 	fullTextResponse.Model = modelName | ||||
| 	usage := openai.ResponseText2Usage(cloudflareResponse.Result.Response, modelName, promptTokens) | ||||
| 	fullTextResponse.Usage = *usage | ||||
| 	fullTextResponse.Id = helper.GetResponseID(c) | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, usage | ||||
| } | ||||
							
								
								
									
										25
									
								
								relay/adaptor/cloudflare/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										25
									
								
								relay/adaptor/cloudflare/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,25 @@ | ||||
| package cloudflare | ||||
|  | ||||
| type Request struct { | ||||
| 	Lora        string  `json:"lora,omitempty"` | ||||
| 	MaxTokens   int     `json:"max_tokens,omitempty"` | ||||
| 	Prompt      string  `json:"prompt,omitempty"` | ||||
| 	Raw         bool    `json:"raw,omitempty"` | ||||
| 	Stream      bool    `json:"stream,omitempty"` | ||||
| 	Temperature float64 `json:"temperature,omitempty"` | ||||
| } | ||||
|  | ||||
| type Result struct { | ||||
| 	Response string `json:"response"` | ||||
| } | ||||
|  | ||||
| type Response struct { | ||||
| 	Result   Result   `json:"result"` | ||||
| 	Success  bool     `json:"success"` | ||||
| 	Errors   []string `json:"errors"` | ||||
| 	Messages []string `json:"messages"` | ||||
| } | ||||
|  | ||||
| type StreamResponse struct { | ||||
| 	Response string `json:"response"` | ||||
| } | ||||
							
								
								
									
										64
									
								
								relay/adaptor/cohere/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										64
									
								
								relay/adaptor/cohere/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,64 @@ | ||||
| package cohere | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct{} | ||||
|  | ||||
| // ConvertImageRequest implements adaptor.Adaptor. | ||||
| func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	return nil, errors.New("not implemented") | ||||
| } | ||||
|  | ||||
| // ConvertImageRequest implements adaptor.Adaptor. | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
|  | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	return fmt.Sprintf("%s/v1/chat", meta.BaseURL), nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	adaptor.SetupCommonRequestHeader(c, req, meta) | ||||
| 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	return ConvertRequest(*request), nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return adaptor.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, resp) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetModelList() []string { | ||||
| 	return ModelList | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "Cohere" | ||||
| } | ||||
							
								
								
									
										14
									
								
								relay/adaptor/cohere/constant.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								relay/adaptor/cohere/constant.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,14 @@ | ||||
| package cohere | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"command", "command-nightly", | ||||
| 	"command-light", "command-light-nightly", | ||||
| 	"command-r", "command-r-plus", | ||||
| } | ||||
|  | ||||
| func init() { | ||||
| 	num := len(ModelList) | ||||
| 	for i := 0; i < num; i++ { | ||||
| 		ModelList = append(ModelList, ModelList[i]+"-internet") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										241
									
								
								relay/adaptor/cohere/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										241
									
								
								relay/adaptor/cohere/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,241 @@ | ||||
| package cohere | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	WebSearchConnector = Connector{ID: "web-search"} | ||||
| ) | ||||
|  | ||||
| func stopReasonCohere2OpenAI(reason *string) string { | ||||
| 	if reason == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	switch *reason { | ||||
| 	case "COMPLETE": | ||||
| 		return "stop" | ||||
| 	default: | ||||
| 		return *reason | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 	cohereRequest := Request{ | ||||
| 		Model:            textRequest.Model, | ||||
| 		Message:          "", | ||||
| 		MaxTokens:        textRequest.MaxTokens, | ||||
| 		Temperature:      textRequest.Temperature, | ||||
| 		P:                textRequest.TopP, | ||||
| 		K:                textRequest.TopK, | ||||
| 		Stream:           textRequest.Stream, | ||||
| 		FrequencyPenalty: textRequest.FrequencyPenalty, | ||||
| 		PresencePenalty:  textRequest.FrequencyPenalty, | ||||
| 		Seed:             int(textRequest.Seed), | ||||
| 	} | ||||
| 	if cohereRequest.Model == "" { | ||||
| 		cohereRequest.Model = "command-r" | ||||
| 	} | ||||
| 	if strings.HasSuffix(cohereRequest.Model, "-internet") { | ||||
| 		cohereRequest.Model = strings.TrimSuffix(cohereRequest.Model, "-internet") | ||||
| 		cohereRequest.Connectors = append(cohereRequest.Connectors, WebSearchConnector) | ||||
| 	} | ||||
| 	for _, message := range textRequest.Messages { | ||||
| 		if message.Role == "user" { | ||||
| 			cohereRequest.Message = message.Content.(string) | ||||
| 		} else { | ||||
| 			var role string | ||||
| 			if message.Role == "assistant" { | ||||
| 				role = "CHATBOT" | ||||
| 			} else if message.Role == "system" { | ||||
| 				role = "SYSTEM" | ||||
| 			} else { | ||||
| 				role = "USER" | ||||
| 			} | ||||
| 			cohereRequest.ChatHistory = append(cohereRequest.ChatHistory, ChatMessage{ | ||||
| 				Role:    role, | ||||
| 				Message: message.Content.(string), | ||||
| 			}) | ||||
| 		} | ||||
| 	} | ||||
| 	return &cohereRequest | ||||
| } | ||||
|  | ||||
| func StreamResponseCohere2OpenAI(cohereResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { | ||||
| 	var response *Response | ||||
| 	var responseText string | ||||
| 	var finishReason string | ||||
|  | ||||
| 	switch cohereResponse.EventType { | ||||
| 	case "stream-start": | ||||
| 		return nil, nil | ||||
| 	case "text-generation": | ||||
| 		responseText += cohereResponse.Text | ||||
| 	case "stream-end": | ||||
| 		usage := cohereResponse.Response.Meta.Tokens | ||||
| 		response = &Response{ | ||||
| 			Meta: Meta{ | ||||
| 				Tokens: Usage{ | ||||
| 					InputTokens:  usage.InputTokens, | ||||
| 					OutputTokens: usage.OutputTokens, | ||||
| 				}, | ||||
| 			}, | ||||
| 		} | ||||
| 		finishReason = *cohereResponse.Response.FinishReason | ||||
| 	default: | ||||
| 		return nil, nil | ||||
| 	} | ||||
|  | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
| 	choice.Delta.Content = responseText | ||||
| 	choice.Delta.Role = "assistant" | ||||
| 	if finishReason != "" { | ||||
| 		choice.FinishReason = &finishReason | ||||
| 	} | ||||
| 	var openaiResponse openai.ChatCompletionsStreamResponse | ||||
| 	openaiResponse.Object = "chat.completion.chunk" | ||||
| 	openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | ||||
| 	return &openaiResponse, response | ||||
| } | ||||
|  | ||||
| func ResponseCohere2OpenAI(cohereResponse *Response) *openai.TextResponse { | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: model.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: cohereResponse.Text, | ||||
| 			Name:    nil, | ||||
| 		}, | ||||
| 		FinishReason: stopReasonCohere2OpenAI(cohereResponse.FinishReason), | ||||
| 	} | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", cohereResponse.ResponseID), | ||||
| 		Model:   "model", | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	createdTime := helper.GetTimestamp() | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := bytes.IndexByte(data, '\n'); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
|  | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	var usage model.Usage | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			var cohereResponse StreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &cohereResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response, meta := StreamResponseCohere2OpenAI(&cohereResponse) | ||||
| 			if meta != nil { | ||||
| 				usage.PromptTokens += meta.Meta.Tokens.InputTokens | ||||
| 				usage.CompletionTokens += meta.Meta.Tokens.OutputTokens | ||||
| 				return true | ||||
| 			} | ||||
| 			if response == nil { | ||||
| 				return true | ||||
| 			} | ||||
| 			response.Id = fmt.Sprintf("chatcmpl-%d", createdTime) | ||||
| 			response.Model = c.GetString("original_model") | ||||
| 			response.Created = createdTime | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	_ = resp.Body.Close() | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var cohereResponse Response | ||||
| 	err = json.Unmarshal(responseBody, &cohereResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if cohereResponse.ResponseID == "" { | ||||
| 		return &model.ErrorWithStatusCode{ | ||||
| 			Error: model.Error{ | ||||
| 				Message: cohereResponse.Message, | ||||
| 				Type:    cohereResponse.Message, | ||||
| 				Param:   "", | ||||
| 				Code:    resp.StatusCode, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := ResponseCohere2OpenAI(&cohereResponse) | ||||
| 	fullTextResponse.Model = modelName | ||||
| 	usage := model.Usage{ | ||||
| 		PromptTokens:     cohereResponse.Meta.Tokens.InputTokens, | ||||
| 		CompletionTokens: cohereResponse.Meta.Tokens.OutputTokens, | ||||
| 		TotalTokens:      cohereResponse.Meta.Tokens.InputTokens + cohereResponse.Meta.Tokens.OutputTokens, | ||||
| 	} | ||||
| 	fullTextResponse.Usage = usage | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	return nil, &usage | ||||
| } | ||||
							
								
								
									
										147
									
								
								relay/adaptor/cohere/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										147
									
								
								relay/adaptor/cohere/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,147 @@ | ||||
| package cohere | ||||
|  | ||||
| type Request struct { | ||||
| 	Message          string        `json:"message" required:"true"` | ||||
| 	Model            string        `json:"model,omitempty"`  // 默认值为"command-r" | ||||
| 	Stream           bool          `json:"stream,omitempty"` // 默认值为false | ||||
| 	Preamble         string        `json:"preamble,omitempty"` | ||||
| 	ChatHistory      []ChatMessage `json:"chat_history,omitempty"` | ||||
| 	ConversationID   string        `json:"conversation_id,omitempty"` | ||||
| 	PromptTruncation string        `json:"prompt_truncation,omitempty"` // 默认值为"AUTO" | ||||
| 	Connectors       []Connector   `json:"connectors,omitempty"` | ||||
| 	Documents        []Document    `json:"documents,omitempty"` | ||||
| 	Temperature      float64       `json:"temperature,omitempty"` // 默认值为0.3 | ||||
| 	MaxTokens        int           `json:"max_tokens,omitempty"` | ||||
| 	MaxInputTokens   int           `json:"max_input_tokens,omitempty"` | ||||
| 	K                int           `json:"k,omitempty"` // 默认值为0 | ||||
| 	P                float64       `json:"p,omitempty"` // 默认值为0.75 | ||||
| 	Seed             int           `json:"seed,omitempty"` | ||||
| 	StopSequences    []string      `json:"stop_sequences,omitempty"` | ||||
| 	FrequencyPenalty float64       `json:"frequency_penalty,omitempty"` // 默认值为0.0 | ||||
| 	PresencePenalty  float64       `json:"presence_penalty,omitempty"`  // 默认值为0.0 | ||||
| 	Tools            []Tool        `json:"tools,omitempty"` | ||||
| 	ToolResults      []ToolResult  `json:"tool_results,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatMessage struct { | ||||
| 	Role    string `json:"role" required:"true"` | ||||
| 	Message string `json:"message" required:"true"` | ||||
| } | ||||
|  | ||||
| type Tool struct { | ||||
| 	Name                 string                   `json:"name" required:"true"` | ||||
| 	Description          string                   `json:"description" required:"true"` | ||||
| 	ParameterDefinitions map[string]ParameterSpec `json:"parameter_definitions"` | ||||
| } | ||||
|  | ||||
| type ParameterSpec struct { | ||||
| 	Description string `json:"description"` | ||||
| 	Type        string `json:"type" required:"true"` | ||||
| 	Required    bool   `json:"required"` | ||||
| } | ||||
|  | ||||
| type ToolResult struct { | ||||
| 	Call    ToolCall                 `json:"call"` | ||||
| 	Outputs []map[string]interface{} `json:"outputs"` | ||||
| } | ||||
|  | ||||
| type ToolCall struct { | ||||
| 	Name       string                 `json:"name" required:"true"` | ||||
| 	Parameters map[string]interface{} `json:"parameters" required:"true"` | ||||
| } | ||||
|  | ||||
| type StreamResponse struct { | ||||
| 	IsFinished    bool            `json:"is_finished"` | ||||
| 	EventType     string          `json:"event_type"` | ||||
| 	GenerationID  string          `json:"generation_id,omitempty"` | ||||
| 	SearchQueries []*SearchQuery  `json:"search_queries,omitempty"` | ||||
| 	SearchResults []*SearchResult `json:"search_results,omitempty"` | ||||
| 	Documents     []*Document     `json:"documents,omitempty"` | ||||
| 	Text          string          `json:"text,omitempty"` | ||||
| 	Citations     []*Citation     `json:"citations,omitempty"` | ||||
| 	Response      *Response       `json:"response,omitempty"` | ||||
| 	FinishReason  string          `json:"finish_reason,omitempty"` | ||||
| } | ||||
|  | ||||
| type SearchQuery struct { | ||||
| 	Text         string `json:"text"` | ||||
| 	GenerationID string `json:"generation_id"` | ||||
| } | ||||
|  | ||||
| type SearchResult struct { | ||||
| 	SearchQuery *SearchQuery `json:"search_query"` | ||||
| 	DocumentIDs []string     `json:"document_ids"` | ||||
| 	Connector   *Connector   `json:"connector"` | ||||
| } | ||||
|  | ||||
| type Connector struct { | ||||
| 	ID string `json:"id"` | ||||
| } | ||||
|  | ||||
| type Document struct { | ||||
| 	ID        string `json:"id"` | ||||
| 	Snippet   string `json:"snippet"` | ||||
| 	Timestamp string `json:"timestamp"` | ||||
| 	Title     string `json:"title"` | ||||
| 	URL       string `json:"url"` | ||||
| } | ||||
|  | ||||
| type Citation struct { | ||||
| 	Start       int      `json:"start"` | ||||
| 	End         int      `json:"end"` | ||||
| 	Text        string   `json:"text"` | ||||
| 	DocumentIDs []string `json:"document_ids"` | ||||
| } | ||||
|  | ||||
| type Response struct { | ||||
| 	ResponseID    string          `json:"response_id"` | ||||
| 	Text          string          `json:"text"` | ||||
| 	GenerationID  string          `json:"generation_id"` | ||||
| 	ChatHistory   []*Message      `json:"chat_history"` | ||||
| 	FinishReason  *string         `json:"finish_reason"` | ||||
| 	Meta          Meta            `json:"meta"` | ||||
| 	Citations     []*Citation     `json:"citations"` | ||||
| 	Documents     []*Document     `json:"documents"` | ||||
| 	SearchResults []*SearchResult `json:"search_results"` | ||||
| 	SearchQueries []*SearchQuery  `json:"search_queries"` | ||||
| 	Message       string          `json:"message"` | ||||
| } | ||||
|  | ||||
| type Message struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type Version struct { | ||||
| 	Version string `json:"version"` | ||||
| } | ||||
|  | ||||
| type Units struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| } | ||||
|  | ||||
| type ChatEntry struct { | ||||
| 	Role    string `json:"role"` | ||||
| 	Message string `json:"message"` | ||||
| } | ||||
|  | ||||
| type Meta struct { | ||||
| 	APIVersion  APIVersion  `json:"api_version"` | ||||
| 	BilledUnits BilledUnits `json:"billed_units"` | ||||
| 	Tokens      Usage       `json:"tokens"` | ||||
| } | ||||
|  | ||||
| type APIVersion struct { | ||||
| 	Version string `json:"version"` | ||||
| } | ||||
|  | ||||
| type BilledUnits struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| } | ||||
|  | ||||
| type Usage struct { | ||||
| 	InputTokens  int `json:"input_tokens"` | ||||
| 	OutputTokens int `json:"output_tokens"` | ||||
| } | ||||
							
								
								
									
										75
									
								
								relay/adaptor/coze/adaptor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								relay/adaptor/coze/adaptor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,75 @@ | ||||
| package coze | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| ) | ||||
|  | ||||
| type Adaptor struct { | ||||
| 	meta *meta.Meta | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| 	a.meta = meta | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	return fmt.Sprintf("%s/open_api/v2/chat", meta.BaseURL), nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	adaptor.SetupCommonRequestHeader(c, req, meta) | ||||
| 	req.Header.Set("Authorization", "Bearer "+meta.APIKey) | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	request.User = a.meta.Config.UserID | ||||
| 	return ConvertRequest(*request), nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { | ||||
| 	if request == nil { | ||||
| 		return nil, errors.New("request is nil") | ||||
| 	} | ||||
| 	return request, nil | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { | ||||
| 	return adaptor.DoRequestHelper(a, c, meta, requestBody) | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { | ||||
| 	var responseText *string | ||||
| 	if meta.IsStream { | ||||
| 		err, responseText = StreamHandler(c, resp) | ||||
| 	} else { | ||||
| 		err, responseText = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) | ||||
| 	} | ||||
| 	if responseText != nil { | ||||
| 		usage = openai.ResponseText2Usage(*responseText, meta.ActualModelName, meta.PromptTokens) | ||||
| 	} else { | ||||
| 		usage = &model.Usage{} | ||||
| 	} | ||||
| 	usage.PromptTokens = meta.PromptTokens | ||||
| 	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens | ||||
| 	return | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetModelList() []string { | ||||
| 	return ModelList | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetChannelName() string { | ||||
| 	return "coze" | ||||
| } | ||||
							
								
								
									
										5
									
								
								relay/adaptor/coze/constant/contenttype/define.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								relay/adaptor/coze/constant/contenttype/define.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| package contenttype | ||||
|  | ||||
| const ( | ||||
| 	Text = "text" | ||||
| ) | ||||
							
								
								
									
										7
									
								
								relay/adaptor/coze/constant/event/define.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								relay/adaptor/coze/constant/event/define.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| package event | ||||
|  | ||||
| const ( | ||||
| 	Message = "message" | ||||
| 	Done    = "done" | ||||
| 	Error   = "error" | ||||
| ) | ||||
							
								
								
									
										6
									
								
								relay/adaptor/coze/constant/messagetype/define.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								relay/adaptor/coze/constant/messagetype/define.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | ||||
| package messagetype | ||||
|  | ||||
| const ( | ||||
| 	Answer   = "answer" | ||||
| 	FollowUp = "follow_up" | ||||
| ) | ||||
							
								
								
									
										3
									
								
								relay/adaptor/coze/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								relay/adaptor/coze/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| package coze | ||||
|  | ||||
| var ModelList = []string{} | ||||
							
								
								
									
										10
									
								
								relay/adaptor/coze/helper.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								relay/adaptor/coze/helper.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | ||||
| package coze | ||||
|  | ||||
| import "github.com/songquanpeng/one-api/relay/adaptor/coze/constant/event" | ||||
|  | ||||
| func event2StopReason(e *string) string { | ||||
| 	if e == nil || *e == event.Message { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return "stop" | ||||
| } | ||||
							
								
								
									
										215
									
								
								relay/adaptor/coze/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										215
									
								
								relay/adaptor/coze/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,215 @@ | ||||
| package coze | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/conv" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/coze/constant/messagetype" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // https://www.coze.com/open | ||||
|  | ||||
| func stopReasonCoze2OpenAI(reason *string) string { | ||||
| 	if reason == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	switch *reason { | ||||
| 	case "end_turn": | ||||
| 		return "stop" | ||||
| 	case "stop_sequence": | ||||
| 		return "stop" | ||||
| 	case "max_tokens": | ||||
| 		return "length" | ||||
| 	default: | ||||
| 		return *reason | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { | ||||
| 	cozeRequest := Request{ | ||||
| 		Stream: textRequest.Stream, | ||||
| 		User:   textRequest.User, | ||||
| 		BotId:  strings.TrimPrefix(textRequest.Model, "bot-"), | ||||
| 	} | ||||
| 	for i, message := range textRequest.Messages { | ||||
| 		if i == len(textRequest.Messages)-1 { | ||||
| 			cozeRequest.Query = message.StringContent() | ||||
| 			continue | ||||
| 		} | ||||
| 		cozeMessage := Message{ | ||||
| 			Role:    message.Role, | ||||
| 			Content: message.StringContent(), | ||||
| 		} | ||||
| 		cozeRequest.ChatHistory = append(cozeRequest.ChatHistory, cozeMessage) | ||||
| 	} | ||||
| 	return &cozeRequest | ||||
| } | ||||
|  | ||||
| func StreamResponseCoze2OpenAI(cozeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { | ||||
| 	var response *Response | ||||
| 	var stopReason string | ||||
| 	var choice openai.ChatCompletionsStreamResponseChoice | ||||
|  | ||||
| 	if cozeResponse.Message != nil { | ||||
| 		if cozeResponse.Message.Type != messagetype.Answer { | ||||
| 			return nil, nil | ||||
| 		} | ||||
| 		choice.Delta.Content = cozeResponse.Message.Content | ||||
| 	} | ||||
| 	choice.Delta.Role = "assistant" | ||||
| 	finishReason := stopReasonCoze2OpenAI(&stopReason) | ||||
| 	if finishReason != "null" { | ||||
| 		choice.FinishReason = &finishReason | ||||
| 	} | ||||
| 	var openaiResponse openai.ChatCompletionsStreamResponse | ||||
| 	openaiResponse.Object = "chat.completion.chunk" | ||||
| 	openaiResponse.Choices = []openai.ChatCompletionsStreamResponseChoice{choice} | ||||
| 	openaiResponse.Id = cozeResponse.ConversationId | ||||
| 	return &openaiResponse, response | ||||
| } | ||||
|  | ||||
| func ResponseCoze2OpenAI(cozeResponse *Response) *openai.TextResponse { | ||||
| 	var responseText string | ||||
| 	for _, message := range cozeResponse.Messages { | ||||
| 		if message.Type == messagetype.Answer { | ||||
| 			responseText = message.Content | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	choice := openai.TextResponseChoice{ | ||||
| 		Index: 0, | ||||
| 		Message: model.Message{ | ||||
| 			Role:    "assistant", | ||||
| 			Content: responseText, | ||||
| 			Name:    nil, | ||||
| 		}, | ||||
| 		FinishReason: "stop", | ||||
| 	} | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", cozeResponse.ConversationId), | ||||
| 		Model:   "coze-bot", | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
| 	} | ||||
| 	return &fullTextResponse | ||||
| } | ||||
|  | ||||
| func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *string) { | ||||
| 	var responseText string | ||||
| 	createdTime := helper.GetTimestamp() | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| 	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { | ||||
| 		if atEOF && len(data) == 0 { | ||||
| 			return 0, nil, nil | ||||
| 		} | ||||
| 		if i := strings.Index(string(data), "\n"); i >= 0 { | ||||
| 			return i + 1, data[0:i], nil | ||||
| 		} | ||||
| 		if atEOF { | ||||
| 			return len(data), data, nil | ||||
| 		} | ||||
| 		return 0, nil, nil | ||||
| 	}) | ||||
| 	dataChan := make(chan string) | ||||
| 	stopChan := make(chan bool) | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 5 { | ||||
| 				continue | ||||
| 			} | ||||
| 			if !strings.HasPrefix(data, "data:") { | ||||
| 				continue | ||||
| 			} | ||||
| 			data = strings.TrimPrefix(data, "data:") | ||||
| 			dataChan <- data | ||||
| 		} | ||||
| 		stopChan <- true | ||||
| 	}() | ||||
| 	common.SetEventStreamHeaders(c) | ||||
| 	var modelName string | ||||
| 	c.Stream(func(w io.Writer) bool { | ||||
| 		select { | ||||
| 		case data := <-dataChan: | ||||
| 			// some implementations may add \r at the end of data | ||||
| 			data = strings.TrimSuffix(data, "\r") | ||||
| 			var cozeResponse StreamResponse | ||||
| 			err := json.Unmarshal([]byte(data), &cozeResponse) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			response, _ := StreamResponseCoze2OpenAI(&cozeResponse) | ||||
| 			if response == nil { | ||||
| 				return true | ||||
| 			} | ||||
| 			for _, choice := range response.Choices { | ||||
| 				responseText += conv.AsString(choice.Delta.Content) | ||||
| 			} | ||||
| 			response.Model = modelName | ||||
| 			response.Created = createdTime | ||||
| 			jsonStr, err := json.Marshal(response) | ||||
| 			if err != nil { | ||||
| 				logger.SysError("error marshalling stream response: " + err.Error()) | ||||
| 				return true | ||||
| 			} | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) | ||||
| 			return true | ||||
| 		case <-stopChan: | ||||
| 			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) | ||||
| 			return false | ||||
| 		} | ||||
| 	}) | ||||
| 	_ = resp.Body.Close() | ||||
| 	return nil, &responseText | ||||
| } | ||||
|  | ||||
| func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *string) { | ||||
| 	responseBody, err := io.ReadAll(resp.Body) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	err = resp.Body.Close() | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	var cozeResponse Response | ||||
| 	err = json.Unmarshal(responseBody, &cozeResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	if cozeResponse.Code != 0 { | ||||
| 		return &model.ErrorWithStatusCode{ | ||||
| 			Error: model.Error{ | ||||
| 				Message: cozeResponse.Msg, | ||||
| 				Code:    cozeResponse.Code, | ||||
| 			}, | ||||
| 			StatusCode: resp.StatusCode, | ||||
| 		}, nil | ||||
| 	} | ||||
| 	fullTextResponse := ResponseCoze2OpenAI(&cozeResponse) | ||||
| 	fullTextResponse.Model = modelName | ||||
| 	jsonResponse, err := json.Marshal(fullTextResponse) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil | ||||
| 	} | ||||
| 	c.Writer.Header().Set("Content-Type", "application/json") | ||||
| 	c.Writer.WriteHeader(resp.StatusCode) | ||||
| 	_, err = c.Writer.Write(jsonResponse) | ||||
| 	var responseText string | ||||
| 	if len(fullTextResponse.Choices) > 0 { | ||||
| 		responseText = fullTextResponse.Choices[0].Message.StringContent() | ||||
| 	} | ||||
| 	return nil, &responseText | ||||
| } | ||||
							
								
								
									
										38
									
								
								relay/adaptor/coze/model.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								relay/adaptor/coze/model.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,38 @@ | ||||
| package coze | ||||
|  | ||||
| type Message struct { | ||||
| 	Role        string `json:"role"` | ||||
| 	Type        string `json:"type"` | ||||
| 	Content     string `json:"content"` | ||||
| 	ContentType string `json:"content_type"` | ||||
| } | ||||
|  | ||||
| type ErrorInformation struct { | ||||
| 	Code int    `json:"code"` | ||||
| 	Msg  string `json:"msg"` | ||||
| } | ||||
|  | ||||
| type Request struct { | ||||
| 	ConversationId string    `json:"conversation_id,omitempty"` | ||||
| 	BotId          string    `json:"bot_id"` | ||||
| 	User           string    `json:"user"` | ||||
| 	Query          string    `json:"query"` | ||||
| 	ChatHistory    []Message `json:"chat_history,omitempty"` | ||||
| 	Stream         bool      `json:"stream"` | ||||
| } | ||||
|  | ||||
| type Response struct { | ||||
| 	ConversationId string    `json:"conversation_id,omitempty"` | ||||
| 	Messages       []Message `json:"messages,omitempty"` | ||||
| 	Code           int       `json:"code,omitempty"` | ||||
| 	Msg            string    `json:"msg,omitempty"` | ||||
| } | ||||
|  | ||||
| type StreamResponse struct { | ||||
| 	Event            string            `json:"event,omitempty"` | ||||
| 	Message          *Message          `json:"message,omitempty"` | ||||
| 	IsFinish         bool              `json:"is_finish,omitempty"` | ||||
| 	Index            int               `json:"index,omitempty"` | ||||
| 	ConversationId   string            `json:"conversation_id,omitempty"` | ||||
| 	ErrorInformation *ErrorInformation `json:"error_information,omitempty"` | ||||
| } | ||||
							
								
								
									
										6
									
								
								relay/adaptor/deepseek/constants.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								relay/adaptor/deepseek/constants.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | ||||
| package deepseek | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"deepseek-chat", | ||||
| 	"deepseek-coder", | ||||
| } | ||||
| @@ -22,7 +22,7 @@ func (a *Adaptor) Init(meta *meta.Meta) { | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 	version := helper.AssignOrDefault(meta.APIVersion, config.GeminiVersion) | ||||
| 	version := helper.AssignOrDefault(meta.Config.APIVersion, config.GeminiVersion) | ||||
| 	action := "generateContent" | ||||
| 	if meta.IsStream { | ||||
| 		action = "streamGenerateContent" | ||||
|   | ||||
| @@ -4,6 +4,10 @@ import ( | ||||
| 	"bufio" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| @@ -13,9 +17,6 @@ import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| ) | ||||
| @@ -54,7 +55,17 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { | ||||
| 			MaxOutputTokens: textRequest.MaxTokens, | ||||
| 		}, | ||||
| 	} | ||||
| 	if textRequest.Functions != nil { | ||||
| 	if textRequest.Tools != nil { | ||||
| 		functions := make([]model.Function, 0, len(textRequest.Tools)) | ||||
| 		for _, tool := range textRequest.Tools { | ||||
| 			functions = append(functions, tool.Function) | ||||
| 		} | ||||
| 		geminiRequest.Tools = []ChatTools{ | ||||
| 			{ | ||||
| 				FunctionDeclarations: functions, | ||||
| 			}, | ||||
| 		} | ||||
| 	} else if textRequest.Functions != nil { | ||||
| 		geminiRequest.Tools = []ChatTools{ | ||||
| 			{ | ||||
| 				FunctionDeclarations: textRequest.Functions, | ||||
| @@ -154,6 +165,30 @@ type ChatPromptFeedback struct { | ||||
| 	SafetyRatings []ChatSafetyRating `json:"safetyRatings"` | ||||
| } | ||||
|  | ||||
| func getToolCalls(candidate *ChatCandidate) []model.Tool { | ||||
| 	var toolCalls []model.Tool | ||||
|  | ||||
| 	item := candidate.Content.Parts[0] | ||||
| 	if item.FunctionCall == nil { | ||||
| 		return toolCalls | ||||
| 	} | ||||
| 	argsBytes, err := json.Marshal(item.FunctionCall.Arguments) | ||||
| 	if err != nil { | ||||
| 		logger.FatalLog("getToolCalls failed: " + err.Error()) | ||||
| 		return toolCalls | ||||
| 	} | ||||
| 	toolCall := model.Tool{ | ||||
| 		Id:   fmt.Sprintf("call_%s", random.GetUUID()), | ||||
| 		Type: "function", | ||||
| 		Function: model.Function{ | ||||
| 			Arguments: string(argsBytes), | ||||
| 			Name:      item.FunctionCall.FunctionName, | ||||
| 		}, | ||||
| 	} | ||||
| 	toolCalls = append(toolCalls, toolCall) | ||||
| 	return toolCalls | ||||
| } | ||||
|  | ||||
| func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()), | ||||
| @@ -165,13 +200,19 @@ func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 		choice := openai.TextResponseChoice{ | ||||
| 			Index: i, | ||||
| 			Message: model.Message{ | ||||
| 				Role:    "assistant", | ||||
| 				Content: "", | ||||
| 				Role: "assistant", | ||||
| 			}, | ||||
| 			FinishReason: constant.StopFinishReason, | ||||
| 		} | ||||
| 		if len(candidate.Content.Parts) > 0 { | ||||
| 			choice.Message.Content = candidate.Content.Parts[0].Text | ||||
| 			if candidate.Content.Parts[0].FunctionCall != nil { | ||||
| 				choice.Message.ToolCalls = getToolCalls(&candidate) | ||||
| 			} else { | ||||
| 				choice.Message.Content = candidate.Content.Parts[0].Text | ||||
| 			} | ||||
| 		} else { | ||||
| 			choice.Message.Content = "" | ||||
| 			choice.FinishReason = candidate.FinishReason | ||||
| 		} | ||||
| 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice) | ||||
| 	} | ||||
|   | ||||
| @@ -12,9 +12,15 @@ type InlineData struct { | ||||
| 	Data     string `json:"data"` | ||||
| } | ||||
|  | ||||
| type FunctionCall struct { | ||||
| 	FunctionName string `json:"name"` | ||||
| 	Arguments    any    `json:"args"` | ||||
| } | ||||
|  | ||||
| type Part struct { | ||||
| 	Text       string      `json:"text,omitempty"` | ||||
| 	InlineData *InlineData `json:"inlineData,omitempty"` | ||||
| 	Text         string        `json:"text,omitempty"` | ||||
| 	InlineData   *InlineData   `json:"inlineData,omitempty"` | ||||
| 	FunctionCall *FunctionCall `json:"functionCall,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatContent struct { | ||||
| @@ -28,7 +34,7 @@ type ChatSafetySettings struct { | ||||
| } | ||||
|  | ||||
| type ChatTools struct { | ||||
| 	FunctionDeclarations any `json:"functionDeclarations,omitempty"` | ||||
| 	FunctionDeclarations any `json:"function_declarations,omitempty"` | ||||
| } | ||||
|  | ||||
| type ChatGenerationConfig struct { | ||||
|   | ||||
| @@ -7,4 +7,6 @@ var ModelList = []string{ | ||||
| 	"llama2-7b-2048", | ||||
| 	"llama2-70b-4096", | ||||
| 	"mixtral-8x7b-32768", | ||||
| 	"llama3-8b-8192", | ||||
| 	"llama3-70b-8192", | ||||
| } | ||||
|   | ||||
| @@ -1,5 +1,11 @@ | ||||
| package ollama | ||||
|  | ||||
| var ModelList = []string{ | ||||
| 	"codellama:7b-instruct", | ||||
| 	"llama2:7b", | ||||
| 	"llama2:latest", | ||||
| 	"llama3:latest", | ||||
| 	"phi3:latest", | ||||
| 	"qwen:0.5b-chat", | ||||
| 	"qwen:7b", | ||||
| } | ||||
|   | ||||
| @@ -53,6 +53,7 @@ func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse { | ||||
| 	} | ||||
| 	fullTextResponse := openai.TextResponse{ | ||||
| 		Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()), | ||||
| 		Model:   response.Model, | ||||
| 		Object:  "chat.completion", | ||||
| 		Created: helper.GetTimestamp(), | ||||
| 		Choices: []openai.TextResponseChoice{choice}, | ||||
|   | ||||
| @@ -29,13 +29,13 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| 		if meta.Mode == relaymode.ImagesGenerations { | ||||
| 			// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api | ||||
| 			// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview | ||||
| 			fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion) | ||||
| 			fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.Config.APIVersion) | ||||
| 			return fullRequestURL, nil | ||||
| 		} | ||||
|  | ||||
| 		// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api | ||||
| 		requestURL := strings.Split(meta.RequestURLPath, "?")[0] | ||||
| 		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) | ||||
| 		requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.Config.APIVersion) | ||||
| 		task := strings.TrimPrefix(requestURL, "/v1/") | ||||
| 		model_ := meta.ActualModelName | ||||
| 		model_ = strings.Replace(model_, ".", "", -1) | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package openai | ||||
| import ( | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/ai360" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/baichuan" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/deepseek" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/groq" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/lingyiwanwu" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/minimax" | ||||
| @@ -22,6 +23,7 @@ var CompatibleChannels = []int{ | ||||
| 	channeltype.Groq, | ||||
| 	channeltype.LingYiWanWu, | ||||
| 	channeltype.StepFun, | ||||
| 	channeltype.DeepSeek, | ||||
| } | ||||
|  | ||||
| func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| @@ -44,6 +46,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) { | ||||
| 		return "lingyiwanwu", lingyiwanwu.ModelList | ||||
| 	case channeltype.StepFun: | ||||
| 		return "stepfun", stepfun.ModelList | ||||
| 	case channeltype.DeepSeek: | ||||
| 		return "deepseek", deepseek.ModelList | ||||
| 	default: | ||||
| 		return "openai", ModelList | ||||
| 	} | ||||
|   | ||||
| @@ -15,6 +15,12 @@ import ( | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	dataPrefix       = "data: " | ||||
| 	done             = "[DONE]" | ||||
| 	dataPrefixLength = len(dataPrefix) | ||||
| ) | ||||
|  | ||||
| func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { | ||||
| 	responseText := "" | ||||
| 	scanner := bufio.NewScanner(resp.Body) | ||||
| @@ -36,39 +42,46 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E | ||||
| 	go func() { | ||||
| 		for scanner.Scan() { | ||||
| 			data := scanner.Text() | ||||
| 			if len(data) < 6 { // ignore blank line or wrong format | ||||
| 			if len(data) < dataPrefixLength { // ignore blank line or wrong format | ||||
| 				continue | ||||
| 			} | ||||
| 			if data[:6] != "data: " && data[:6] != "[DONE]" { | ||||
| 			if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done { | ||||
| 				continue | ||||
| 			} | ||||
| 			dataChan <- data | ||||
| 			data = data[6:] | ||||
| 			if !strings.HasPrefix(data, "[DONE]") { | ||||
| 				switch relayMode { | ||||
| 				case relaymode.ChatCompletions: | ||||
| 					var streamResponse ChatCompletionsStreamResponse | ||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||
| 					if err != nil { | ||||
| 						logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 						continue // just ignore the error | ||||
| 					} | ||||
| 					for _, choice := range streamResponse.Choices { | ||||
| 						responseText += conv.AsString(choice.Delta.Content) | ||||
| 					} | ||||
| 					if streamResponse.Usage != nil { | ||||
| 						usage = streamResponse.Usage | ||||
| 					} | ||||
| 				case relaymode.Completions: | ||||
| 					var streamResponse CompletionsStreamResponse | ||||
| 					err := json.Unmarshal([]byte(data), &streamResponse) | ||||
| 					if err != nil { | ||||
| 						logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 						continue | ||||
| 					} | ||||
| 					for _, choice := range streamResponse.Choices { | ||||
| 						responseText += choice.Text | ||||
| 					} | ||||
| 			if strings.HasPrefix(data[dataPrefixLength:], done) { | ||||
| 				dataChan <- data | ||||
| 				continue | ||||
| 			} | ||||
| 			switch relayMode { | ||||
| 			case relaymode.ChatCompletions: | ||||
| 				var streamResponse ChatCompletionsStreamResponse | ||||
| 				err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) | ||||
| 				if err != nil { | ||||
| 					logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 					dataChan <- data // if error happened, pass the data to client | ||||
| 					continue         // just ignore the error | ||||
| 				} | ||||
| 				if len(streamResponse.Choices) == 0 { | ||||
| 					// but for empty choice, we should not pass it to client, this is for azure | ||||
| 					continue // just ignore empty choice | ||||
| 				} | ||||
| 				dataChan <- data | ||||
| 				for _, choice := range streamResponse.Choices { | ||||
| 					responseText += conv.AsString(choice.Delta.Content) | ||||
| 				} | ||||
| 				if streamResponse.Usage != nil { | ||||
| 					usage = streamResponse.Usage | ||||
| 				} | ||||
| 			case relaymode.Completions: | ||||
| 				dataChan <- data | ||||
| 				var streamResponse CompletionsStreamResponse | ||||
| 				err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) | ||||
| 				if err != nil { | ||||
| 					logger.SysError("error unmarshalling stream response: " + err.Error()) | ||||
| 					continue | ||||
| 				} | ||||
| 				for _, choice := range streamResponse.Choices { | ||||
| 					responseText += choice.Text | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|   | ||||
| @@ -134,7 +134,7 @@ type ChatCompletionsStreamResponse struct { | ||||
| 	Created int64                                 `json:"created"` | ||||
| 	Model   string                                `json:"model"` | ||||
| 	Choices []ChatCompletionsStreamResponseChoice `json:"choices"` | ||||
| 	Usage   *model.Usage                          `json:"usage"` | ||||
| 	Usage   *model.Usage                          `json:"usage,omitempty"` | ||||
| } | ||||
|  | ||||
| type CompletionsStreamResponse struct { | ||||
|   | ||||
| @@ -14,10 +14,11 @@ import ( | ||||
|  | ||||
| type Adaptor struct { | ||||
| 	request *model.GeneralOpenAIRequest | ||||
| 	meta    *meta.Meta | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) Init(meta *meta.Meta) { | ||||
|  | ||||
| 	a.meta = meta | ||||
| } | ||||
|  | ||||
| func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
| @@ -26,6 +27,14 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { | ||||
|  | ||||
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { | ||||
| 	adaptor.SetupCommonRequestHeader(c, req, meta) | ||||
| 	version := parseAPIVersionByModelName(meta.ActualModelName) | ||||
| 	if version == "" { | ||||
| 		version = a.meta.Config.APIVersion | ||||
| 	} | ||||
| 	if version == "" { | ||||
| 		version = "v1.1" | ||||
| 	} | ||||
| 	a.meta.Config.APIVersion = version | ||||
| 	// check DoResponse for auth part | ||||
| 	return nil | ||||
| } | ||||
| @@ -61,9 +70,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met | ||||
| 		return nil, openai.ErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest) | ||||
| 	} | ||||
| 	if meta.IsStream { | ||||
| 		err, usage = StreamHandler(c, *a.request, splits[0], splits[1], splits[2]) | ||||
| 		err, usage = StreamHandler(c, meta, *a.request, splits[0], splits[1], splits[2]) | ||||
| 	} else { | ||||
| 		err, usage = Handler(c, *a.request, splits[0], splits[1], splits[2]) | ||||
| 		err, usage = Handler(c, meta, *a.request, splits[0], splits[1], splits[2]) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
|   | ||||
| @@ -9,12 +9,12 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/gorilla/websocket" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/helper" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/common/random" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/constant" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	"github.com/songquanpeng/one-api/relay/model" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| @@ -149,8 +149,8 @@ func buildXunfeiAuthUrl(hostUrl string, apiKey, apiSecret string) string { | ||||
| 	return callUrl | ||||
| } | ||||
|  | ||||
| func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) | ||||
| func StreamHandler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret) | ||||
| 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil | ||||
| @@ -179,8 +179,8 @@ func StreamHandler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId | ||||
| 	return nil, &usage | ||||
| } | ||||
|  | ||||
| func Handler(c *gin.Context, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	domain, authUrl := getXunfeiAuthUrl(c, apiKey, apiSecret, textRequest.Model) | ||||
| func Handler(c *gin.Context, meta *meta.Meta, textRequest model.GeneralOpenAIRequest, appId string, apiSecret string, apiKey string) (*model.ErrorWithStatusCode, *model.Usage) { | ||||
| 	domain, authUrl := getXunfeiAuthUrl(meta.Config.APIVersion, apiKey, apiSecret) | ||||
| 	dataChan, stopChan, err := xunfeiMakeRequest(textRequest, domain, authUrl, appId) | ||||
| 	if err != nil { | ||||
| 		return openai.ErrorWrapper(err, "xunfei_request_failed", http.StatusInternalServerError), nil | ||||
| @@ -268,25 +268,12 @@ func xunfeiMakeRequest(textRequest model.GeneralOpenAIRequest, domain, authUrl, | ||||
| 	return dataChan, stopChan, nil | ||||
| } | ||||
|  | ||||
| func getAPIVersion(c *gin.Context, modelName string) string { | ||||
| 	query := c.Request.URL.Query() | ||||
| 	apiVersion := query.Get("api-version") | ||||
| 	if apiVersion != "" { | ||||
| 		return apiVersion | ||||
| 	} | ||||
| func parseAPIVersionByModelName(modelName string) string { | ||||
| 	parts := strings.Split(modelName, "-") | ||||
| 	if len(parts) == 2 { | ||||
| 		apiVersion = parts[1] | ||||
| 		return apiVersion | ||||
|  | ||||
| 		return parts[1] | ||||
| 	} | ||||
| 	apiVersion = c.GetString(config.KeyAPIVersion) | ||||
| 	if apiVersion != "" { | ||||
| 		return apiVersion | ||||
| 	} | ||||
| 	apiVersion = "v1.1" | ||||
| 	logger.SysLog("api_version not found, using default: " + apiVersion) | ||||
| 	return apiVersion | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| // https://www.xfyun.cn/doc/spark/Web.html#_1-%E6%8E%A5%E5%8F%A3%E8%AF%B4%E6%98%8E | ||||
| @@ -304,8 +291,7 @@ func apiVersion2domain(apiVersion string) string { | ||||
| 	return "general" + apiVersion | ||||
| } | ||||
|  | ||||
| func getXunfeiAuthUrl(c *gin.Context, apiKey string, apiSecret string, modelName string) (string, string) { | ||||
| 	apiVersion := getAPIVersion(c, modelName) | ||||
| func getXunfeiAuthUrl(apiVersion string, apiKey string, apiSecret string) (string, string) { | ||||
| 	domain := apiVersion2domain(apiVersion) | ||||
| 	authUrl := buildXunfeiAuthUrl(fmt.Sprintf("wss://spark-api.xf-yun.com/%s/chat", apiVersion), apiKey, apiSecret) | ||||
| 	return domain, authUrl | ||||
|   | ||||
| @@ -13,6 +13,9 @@ const ( | ||||
| 	Gemini | ||||
| 	Ollama | ||||
| 	AwsClaude | ||||
| 	Coze | ||||
| 	Cohere | ||||
| 	Cloudflare | ||||
|  | ||||
| 	Dummy // this one is only for count, do not add any channel after this | ||||
| ) | ||||
|   | ||||
| @@ -2,8 +2,9 @@ package ratio | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"strings" | ||||
|  | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| @@ -147,11 +148,13 @@ var ModelRatio = map[string]float64{ | ||||
| 	"mistral-medium-latest": 2.7 / 1000 * USD, | ||||
| 	"mistral-large-latest":  8.0 / 1000 * USD, | ||||
| 	"mistral-embed":         0.1 / 1000 * USD, | ||||
| 	// https://wow.groq.com/ | ||||
| 	"llama2-70b-4096":    0.7 / 1000 * USD, | ||||
| 	"llama2-7b-2048":     0.1 / 1000 * USD, | ||||
| 	// https://wow.groq.com/#:~:text=inquiries%C2%A0here.-,Model,-Current%20Speed | ||||
| 	"llama3-70b-8192":    0.59 / 1000 * USD, | ||||
| 	"mixtral-8x7b-32768": 0.27 / 1000 * USD, | ||||
| 	"llama3-8b-8192":     0.05 / 1000 * USD, | ||||
| 	"gemma-7b-it":        0.1 / 1000 * USD, | ||||
| 	"llama2-70b-4096":    0.64 / 1000 * USD, | ||||
| 	"llama2-7b-2048":     0.1 / 1000 * USD, | ||||
| 	// https://platform.lingyiwanwu.com/docs#-计费单元 | ||||
| 	"yi-34b-chat-0205": 2.5 / 1000 * RMB, | ||||
| 	"yi-34b-chat-200k": 12.0 / 1000 * RMB, | ||||
| @@ -160,6 +163,16 @@ var ModelRatio = map[string]float64{ | ||||
| 	"step-1v-32k": 0.024 * RMB, | ||||
| 	"step-1-32k":  0.024 * RMB, | ||||
| 	"step-1-200k": 0.15 * RMB, | ||||
| 	// https://cohere.com/pricing | ||||
| 	"command":               0.5, | ||||
| 	"command-nightly":       0.5, | ||||
| 	"command-light":         0.5, | ||||
| 	"command-light-nightly": 0.5, | ||||
| 	"command-r":             0.5 / 1000 * USD, | ||||
| 	"command-r-plus	":       3.0 / 1000 * USD, | ||||
| 	// https://platform.deepseek.com/api-docs/pricing/ | ||||
| 	"deepseek-chat":  1.0 / 1000 * RMB, | ||||
| 	"deepseek-coder": 1.0 / 1000 * RMB, | ||||
| } | ||||
|  | ||||
| var CompletionRatio = map[string]float64{} | ||||
| @@ -215,6 +228,9 @@ func GetModelRatio(name string) float64 { | ||||
| 	if strings.HasPrefix(name, "qwen-") && strings.HasSuffix(name, "-internet") { | ||||
| 		name = strings.TrimSuffix(name, "-internet") | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "command-") && strings.HasSuffix(name, "-internet") { | ||||
| 		name = strings.TrimSuffix(name, "-internet") | ||||
| 	} | ||||
| 	ratio, ok := ModelRatio[name] | ||||
| 	if !ok { | ||||
| 		ratio, ok = DefaultModelRatio[name] | ||||
| @@ -258,7 +274,7 @@ func GetCompletionRatio(name string) float64 { | ||||
| 		return 4.0 / 3.0 | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "gpt-4") { | ||||
| 		if strings.HasPrefix(name, "gpt-4-turbo") { | ||||
| 		if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") { | ||||
| 			return 3 | ||||
| 		} | ||||
| 		return 2 | ||||
| @@ -275,9 +291,22 @@ func GetCompletionRatio(name string) float64 { | ||||
| 	if strings.HasPrefix(name, "gemini-") { | ||||
| 		return 3 | ||||
| 	} | ||||
| 	if strings.HasPrefix(name, "deepseek-") { | ||||
| 		return 2 | ||||
| 	} | ||||
| 	switch name { | ||||
| 	case "llama2-70b-4096": | ||||
| 		return 0.8 / 0.7 | ||||
| 		return 0.8 / 0.64 | ||||
| 	case "llama3-8b-8192": | ||||
| 		return 2 | ||||
| 	case "llama3-70b-8192": | ||||
| 		return 0.79 / 0.59 | ||||
| 	case "command", "command-light", "command-nightly", "command-light-nightly": | ||||
| 		return 2 | ||||
| 	case "command-r": | ||||
| 		return 3 | ||||
| 	case "command-r-plus": | ||||
| 		return 5 | ||||
| 	} | ||||
| 	return 1 | ||||
| } | ||||
|   | ||||
| @@ -35,6 +35,10 @@ const ( | ||||
| 	LingYiWanWu | ||||
| 	StepFun | ||||
| 	AwsClaude | ||||
| 	Coze | ||||
| 	Cohere | ||||
| 	DeepSeek | ||||
| 	Cloudflare | ||||
|  | ||||
| 	Dummy | ||||
| ) | ||||
|   | ||||
| @@ -27,6 +27,12 @@ func ToAPIType(channelType int) int { | ||||
| 		apiType = apitype.Ollama | ||||
| 	case AwsClaude: | ||||
| 		apiType = apitype.AwsClaude | ||||
| 	case Coze: | ||||
| 		apiType = apitype.Coze | ||||
| 	case Cohere: | ||||
| 		apiType = apitype.Cohere | ||||
| 	case Cloudflare: | ||||
| 		apiType = apitype.Cloudflare | ||||
| 	} | ||||
|  | ||||
| 	return apiType | ||||
|   | ||||
| @@ -35,6 +35,10 @@ var ChannelBaseURLs = []string{ | ||||
| 	"https://api.lingyiwanwu.com",               // 31 | ||||
| 	"https://api.stepfun.com",                   // 32 | ||||
| 	"",                                          // 33 | ||||
| 	"https://api.coze.com",                      // 34 | ||||
| 	"https://api.cohere.ai",                     // 35 | ||||
| 	"https://api.deepseek.com",                  // 36 | ||||
| 	"https://api.cloudflare.com",                // 37 | ||||
| } | ||||
|  | ||||
| func init() { | ||||
|   | ||||
| @@ -10,14 +10,15 @@ import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/azure" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/openai" | ||||
| 	"github.com/songquanpeng/one-api/relay/billing" | ||||
| 	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/client" | ||||
| 	"github.com/songquanpeng/one-api/relay/meta" | ||||
| 	relaymodel "github.com/songquanpeng/one-api/relay/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| 	"io" | ||||
| @@ -27,14 +28,15 @@ import ( | ||||
|  | ||||
| func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { | ||||
| 	ctx := c.Request.Context() | ||||
| 	meta := meta.GetByContext(c) | ||||
| 	audioModel := "whisper-1" | ||||
|  | ||||
| 	tokenId := c.GetInt("token_id") | ||||
| 	channelType := c.GetInt("channel") | ||||
| 	channelId := c.GetInt("channel_id") | ||||
| 	userId := c.GetInt("id") | ||||
| 	group := c.GetString("group") | ||||
| 	tokenName := c.GetString("token_name") | ||||
| 	tokenId := c.GetInt(ctxkey.TokenId) | ||||
| 	channelType := c.GetInt(ctxkey.Channel) | ||||
| 	channelId := c.GetInt(ctxkey.ChannelId) | ||||
| 	userId := c.GetInt(ctxkey.Id) | ||||
| 	group := c.GetString(ctxkey.Group) | ||||
| 	tokenName := c.GetString(ctxkey.TokenName) | ||||
|  | ||||
| 	var ttsRequest openai.TextToSpeechRequest | ||||
| 	if relayMode == relaymode.AudioSpeech { | ||||
| @@ -107,7 +109,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	}() | ||||
|  | ||||
| 	// map model name | ||||
| 	modelMapping := c.GetString("model_mapping") | ||||
| 	modelMapping := c.GetString(ctxkey.ModelMapping) | ||||
| 	if modelMapping != "" { | ||||
| 		modelMap := make(map[string]string) | ||||
| 		err := json.Unmarshal([]byte(modelMapping), &modelMap) | ||||
| @@ -121,13 +123,13 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
|  | ||||
| 	baseURL := channeltype.ChannelBaseURLs[channelType] | ||||
| 	requestURL := c.Request.URL.String() | ||||
| 	if c.GetString("base_url") != "" { | ||||
| 		baseURL = c.GetString("base_url") | ||||
| 	if c.GetString(ctxkey.BaseURL) != "" { | ||||
| 		baseURL = c.GetString(ctxkey.BaseURL) | ||||
| 	} | ||||
|  | ||||
| 	fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType) | ||||
| 	if channelType == channeltype.Azure { | ||||
| 		apiVersion := azure.GetAPIVersion(c) | ||||
| 		apiVersion := meta.Config.APIVersion | ||||
| 		if relayMode == relaymode.AudioTranscription { | ||||
| 			// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api | ||||
| 			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) | ||||
|   | ||||
| @@ -7,6 +7,7 @@ import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/common/logger" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay" | ||||
| @@ -69,6 +70,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 	if adaptor == nil { | ||||
| 		return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) | ||||
| 	} | ||||
| 	adaptor.Init(meta) | ||||
|  | ||||
| 	switch meta.ChannelType { | ||||
| 	case channeltype.Ali: | ||||
| @@ -119,11 +121,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus | ||||
| 			logger.SysError("error update user quota cache: " + err.Error()) | ||||
| 		} | ||||
| 		if quota != 0 { | ||||
| 			tokenName := c.GetString("token_name") | ||||
| 			tokenName := c.GetString(ctxkey.TokenName) | ||||
| 			logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) | ||||
| 			model.RecordConsumeLog(ctx, meta.UserId, meta.ChannelId, 0, 0, imageRequest.Model, tokenName, quota, logContent) | ||||
| 			model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) | ||||
| 			channelId := c.GetInt("channel_id") | ||||
| 			channelId := c.GetInt(ctxkey.ChannelId) | ||||
| 			model.UpdateChannelUsedQuota(channelId, quota) | ||||
| 		} | ||||
| 	}(c.Request.Context()) | ||||
|   | ||||
| @@ -53,6 +53,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { | ||||
| 	if adaptor == nil { | ||||
| 		return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) | ||||
| 	} | ||||
| 	adaptor.Init(meta) | ||||
|  | ||||
| 	// get request body | ||||
| 	var requestBody io.Reader | ||||
|   | ||||
| @@ -2,8 +2,8 @@ package meta | ||||
|  | ||||
| import ( | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/songquanpeng/one-api/common/config" | ||||
| 	"github.com/songquanpeng/one-api/relay/adaptor/azure" | ||||
| 	"github.com/songquanpeng/one-api/common/ctxkey" | ||||
| 	"github.com/songquanpeng/one-api/model" | ||||
| 	"github.com/songquanpeng/one-api/relay/channeltype" | ||||
| 	"github.com/songquanpeng/one-api/relay/relaymode" | ||||
| 	"strings" | ||||
| @@ -19,10 +19,9 @@ type Meta struct { | ||||
| 	Group           string | ||||
| 	ModelMapping    map[string]string | ||||
| 	BaseURL         string | ||||
| 	APIVersion      string | ||||
| 	APIKey          string | ||||
| 	APIType         int | ||||
| 	Config          map[string]string | ||||
| 	Config          model.ChannelConfig | ||||
| 	IsStream        bool | ||||
| 	OriginModelName string | ||||
| 	ActualModelName string | ||||
| @@ -32,22 +31,22 @@ type Meta struct { | ||||
|  | ||||
| func GetByContext(c *gin.Context) *Meta { | ||||
| 	meta := Meta{ | ||||
| 		Mode:           relaymode.GetByPath(c.Request.URL.Path), | ||||
| 		ChannelType:    c.GetInt("channel"), | ||||
| 		ChannelId:      c.GetInt("channel_id"), | ||||
| 		TokenId:        c.GetInt("token_id"), | ||||
| 		TokenName:      c.GetString("token_name"), | ||||
| 		UserId:         c.GetInt("id"), | ||||
| 		Group:          c.GetString("group"), | ||||
| 		ModelMapping:   c.GetStringMapString("model_mapping"), | ||||
| 		BaseURL:        c.GetString("base_url"), | ||||
| 		APIVersion:     c.GetString(config.KeyAPIVersion), | ||||
| 		APIKey:         strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), | ||||
| 		Config:         nil, | ||||
| 		RequestURLPath: c.Request.URL.String(), | ||||
| 		Mode:            relaymode.GetByPath(c.Request.URL.Path), | ||||
| 		ChannelType:     c.GetInt(ctxkey.Channel), | ||||
| 		ChannelId:       c.GetInt(ctxkey.ChannelId), | ||||
| 		TokenId:         c.GetInt(ctxkey.TokenId), | ||||
| 		TokenName:       c.GetString(ctxkey.TokenName), | ||||
| 		UserId:          c.GetInt(ctxkey.Id), | ||||
| 		Group:           c.GetString(ctxkey.Group), | ||||
| 		ModelMapping:    c.GetStringMapString(ctxkey.ModelMapping), | ||||
| 		OriginModelName: c.GetString(ctxkey.RequestModel), | ||||
| 		BaseURL:         c.GetString(ctxkey.BaseURL), | ||||
| 		APIKey:          strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), | ||||
| 		RequestURLPath:  c.Request.URL.String(), | ||||
| 	} | ||||
| 	if meta.ChannelType == channeltype.Azure { | ||||
| 		meta.APIVersion = azure.GetAPIVersion(c) | ||||
| 	cfg, ok := c.Get(ctxkey.Config) | ||||
| 	if ok { | ||||
| 		meta.Config = cfg.(model.ChannelConfig) | ||||
| 	} | ||||
| 	if meta.BaseURL == "" { | ||||
| 		meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType] | ||||
|   | ||||
| @@ -11,6 +11,12 @@ export const CHANNEL_OPTIONS = { | ||||
|     value: 14, | ||||
|     color: 'primary' | ||||
|   }, | ||||
|   // 33: { | ||||
|   //   key: 33, | ||||
|   //   text: 'AWS Claude', | ||||
|   //   value: 33, | ||||
|   //   color: 'primary' | ||||
|   // }, | ||||
|   3: { | ||||
|     key: 3, | ||||
|     text: 'Azure OpenAI', | ||||
| @@ -113,6 +119,24 @@ export const CHANNEL_OPTIONS = { | ||||
|     value: 32, | ||||
|     color: 'primary' | ||||
|   }, | ||||
|   // 34: { | ||||
|   //   key: 34, | ||||
|   //   text: 'Coze', | ||||
|   //   value: 34, | ||||
|   //   color: 'primary' | ||||
|   // }, | ||||
|   35: { | ||||
|     key: 35, | ||||
|     text: 'Cohere', | ||||
|     value: 35, | ||||
|     color: 'primary' | ||||
|   }, | ||||
|   36: { | ||||
|     key: 36, | ||||
|     text: 'DeepSeek', | ||||
|     value: 36, | ||||
|     color: 'primary' | ||||
|   }, | ||||
|   8: { | ||||
|     key: 8, | ||||
|     text: '自定义渠道', | ||||
|   | ||||
| @@ -33,7 +33,7 @@ function renderType(type) { | ||||
|     } | ||||
|     type2label[0] = { value: 0, text: '未知类型', color: 'grey' }; | ||||
|   } | ||||
|   return <Label basic color={type2label[type]?.color}>{type2label[type]?.text}</Label>; | ||||
|   return <Label basic color={type2label[type]?.color}>{type2label[type] ? type2label[type].text : type}</Label>; | ||||
| } | ||||
|  | ||||
| function renderBalance(type, balance) { | ||||
|   | ||||
| @@ -13,6 +13,7 @@ const COPY_OPTIONS = [ | ||||
| ]; | ||||
|  | ||||
| const OPEN_LINK_OPTIONS = [ | ||||
|   { key: 'next', text: 'ChatGPT Next Web', value: 'next' }, | ||||
|   { key: 'ama', text: 'BotGem', value: 'ama' }, | ||||
|   { key: 'opencat', text: 'OpenCat', value: 'opencat' }, | ||||
| ]; | ||||
|   | ||||
| @@ -1,35 +1,39 @@ | ||||
| export const CHANNEL_OPTIONS = [ | ||||
|   { key: 1, text: 'OpenAI', value: 1, color: 'green' }, | ||||
|   { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, | ||||
|   { key: 33, text: 'AWS Claude', value: 33, color: 'black' }, | ||||
|   { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, | ||||
|   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, | ||||
|   { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, | ||||
|   { key: 28, text: 'Mistral AI', value: 28, color: 'orange' }, | ||||
|   { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, | ||||
|   { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, | ||||
|   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, | ||||
|   { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, | ||||
|   { key: 19, text: '360 智脑', value: 19, color: 'blue' }, | ||||
|   { key: 25, text: 'Moonshot AI', value: 25, color: 'black' }, | ||||
|   { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, | ||||
|   { key: 26, text: '百川大模型', value: 26, color: 'orange' }, | ||||
|   { key: 27, text: 'MiniMax', value: 27, color: 'red' }, | ||||
|   { key: 29, text: 'Groq', value: 29, color: 'orange' }, | ||||
|   { key: 30, text: 'Ollama', value: 30, color: 'black' }, | ||||
|   { key: 31, text: '零一万物', value: 31, color: 'green' }, | ||||
|   { key: 32, text: '阶跃星辰', value: 32, color: 'blue' }, | ||||
|   { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, | ||||
|   { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, | ||||
|   { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, | ||||
|   { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, | ||||
|   { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, | ||||
|   { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, | ||||
|   { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' }, | ||||
|   { key: 10, text: '代理:AI Proxy', value: 10, color: 'purple' }, | ||||
|   { key: 4, text: '代理:CloseAI', value: 4, color: 'teal' }, | ||||
|   { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' }, | ||||
|   { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' }, | ||||
|   { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' }, | ||||
|   { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' } | ||||
|     {key: 1, text: 'OpenAI', value: 1, color: 'green'}, | ||||
|     {key: 14, text: 'Anthropic Claude', value: 14, color: 'black'}, | ||||
|     {key: 33, text: 'AWS Claude', value: 33, color: 'black'}, | ||||
|     {key: 3, text: 'Azure OpenAI', value: 3, color: 'olive'}, | ||||
|     {key: 11, text: 'Google PaLM2', value: 11, color: 'orange'}, | ||||
|     {key: 24, text: 'Google Gemini', value: 24, color: 'orange'}, | ||||
|     {key: 28, text: 'Mistral AI', value: 28, color: 'orange'}, | ||||
|     {key: 15, text: '百度文心千帆', value: 15, color: 'blue'}, | ||||
|     {key: 17, text: '阿里通义千问', value: 17, color: 'orange'}, | ||||
|     {key: 18, text: '讯飞星火认知', value: 18, color: 'blue'}, | ||||
|     {key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet'}, | ||||
|     {key: 19, text: '360 智脑', value: 19, color: 'blue'}, | ||||
|     {key: 25, text: 'Moonshot AI', value: 25, color: 'black'}, | ||||
|     {key: 23, text: '腾讯混元', value: 23, color: 'teal'}, | ||||
|     {key: 26, text: '百川大模型', value: 26, color: 'orange'}, | ||||
|     {key: 27, text: 'MiniMax', value: 27, color: 'red'}, | ||||
|     {key: 29, text: 'Groq', value: 29, color: 'orange'}, | ||||
|     {key: 30, text: 'Ollama', value: 30, color: 'black'}, | ||||
|     {key: 31, text: '零一万物', value: 31, color: 'green'}, | ||||
|     {key: 32, text: '阶跃星辰', value: 32, color: 'blue'}, | ||||
|     {key: 34, text: 'Coze', value: 34, color: 'blue'}, | ||||
|     {key: 35, text: 'Cohere', value: 35, color: 'blue'}, | ||||
|     {key: 36, text: 'DeepSeek', value: 36, color: 'black'}, | ||||
|     {key: 37, text: 'Cloudflare', value: 37, color: 'orange'}, | ||||
|     {key: 8, text: '自定义渠道', value: 8, color: 'pink'}, | ||||
|     {key: 22, text: '知识库:FastGPT', value: 22, color: 'blue'}, | ||||
|     {key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple'}, | ||||
|     {key: 20, text: '代理:OpenRouter', value: 20, color: 'black'}, | ||||
|     {key: 2, text: '代理:API2D', value: 2, color: 'blue'}, | ||||
|     {key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown'}, | ||||
|     {key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple'}, | ||||
|     {key: 10, text: '代理:AI Proxy', value: 10, color: 'purple'}, | ||||
|     {key: 4, text: '代理:CloseAI', value: 4, color: 'teal'}, | ||||
|     {key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet'}, | ||||
|     {key: 9, text: '代理:AI.LS', value: 9, color: 'yellow'}, | ||||
|     {key: 12, text: '代理:API2GPT', value: 12, color: 'blue'}, | ||||
|     {key: 13, text: '代理:AIGC2D', value: 13, color: 'purple'} | ||||
| ]; | ||||
|   | ||||
| @@ -57,7 +57,8 @@ const EditChannel = () => { | ||||
|   const [config, setConfig] = useState({ | ||||
|     region: '', | ||||
|     sk: '', | ||||
|     ak: '' | ||||
|     ak: '', | ||||
|     user_id: '' | ||||
|   }); | ||||
|   const handleInputChange = (e, { name, value }) => { | ||||
|     setInputs((inputs) => ({ ...inputs, [name]: value })); | ||||
| @@ -156,8 +157,10 @@ const EditChannel = () => { | ||||
|   }, []); | ||||
|  | ||||
|   const submit = async () => { | ||||
|     if (inputs.key === "") { | ||||
|       inputs.key = `${config.ak}|${config.sk}|${config.region}`; | ||||
|     if (inputs.key === '') { | ||||
|       if (config.ak !== '' && config.sk !== '' && config.region !== '') { | ||||
|         inputs.key = `${config.ak}|${config.sk}|${config.region}`; | ||||
|       } | ||||
|     } | ||||
|     if (!isEdit && (inputs.name === '' || inputs.key === '')) { | ||||
|       showInfo('请填写渠道名称和渠道密钥!'); | ||||
| @@ -171,7 +174,7 @@ const EditChannel = () => { | ||||
|       showInfo('模型映射必须是合法的 JSON 格式!'); | ||||
|       return; | ||||
|     } | ||||
|     let localInputs = inputs; | ||||
|     let localInputs = {...inputs}; | ||||
|     if (localInputs.base_url && localInputs.base_url.endsWith('/')) { | ||||
|       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); | ||||
|     } | ||||
| @@ -352,6 +355,13 @@ const EditChannel = () => { | ||||
|               </Form.Field> | ||||
|             ) | ||||
|           } | ||||
|           { | ||||
|             inputs.type === 34 && ( | ||||
|               <Message> | ||||
|                 对于 Coze 而言,模型名称即 Bot ID,你可以添加一个前缀 `bot-`,例如:`bot-123456`。 | ||||
|               </Message> | ||||
|             ) | ||||
|           } | ||||
|           <Form.Field> | ||||
|             <Form.Dropdown | ||||
|               label='模型' | ||||
| @@ -442,6 +452,18 @@ const EditChannel = () => { | ||||
|               </Form.Field> | ||||
|             ) | ||||
|           } | ||||
|           { | ||||
|             inputs.type === 34 && ( | ||||
|               <Form.Input | ||||
|                 label='User ID' | ||||
|                 name='user_id' | ||||
|                 required | ||||
|                 placeholder={'生成该密钥的用户 ID'} | ||||
|                 onChange={handleConfigChange} | ||||
|                 value={config.user_id} | ||||
|                 autoComplete='' | ||||
|               />) | ||||
|           } | ||||
|           { | ||||
|             inputs.type !== 33 && (batch ? <Form.Field> | ||||
|               <Form.TextArea | ||||
| @@ -466,6 +488,21 @@ const EditChannel = () => { | ||||
|               /> | ||||
|             </Form.Field>) | ||||
|           } | ||||
|           { | ||||
|             inputs.type === 37 && ( | ||||
|               <Form.Field> | ||||
|                 <Form.Input | ||||
|                   label='Account ID' | ||||
|                   name='user_id' | ||||
|                   required | ||||
|                   placeholder={'请输入 Account ID,例如:d8d7c61dbc334c32d3ced580e4bf42b4'} | ||||
|                   onChange={handleConfigChange} | ||||
|                   value={config.user_id} | ||||
|                   autoComplete='' | ||||
|                 /> | ||||
|               </Form.Field> | ||||
|             ) | ||||
|           } | ||||
|           { | ||||
|             inputs.type !== 33 && !isEdit && ( | ||||
|               <Form.Checkbox | ||||
|   | ||||
		Reference in New Issue
	
	Block a user