mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-31 13:53:41 +08:00 
			
		
		
		
	Compare commits
	
		
			7 Commits
		
	
	
		
			v0.5.6-alp
			...
			v0.5.6-alp
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|  | de868e4e4e | ||
|  | 1d258cc898 | ||
|  | 37e09d764c | ||
|  | 159b9e3369 | ||
|  | 92001986db | ||
|  | a5647b1ea7 | ||
|  | 215e54fc96 | 
| @@ -269,6 +269,12 @@ docker run --name chatgpt-web -d -p 3002:3002 -e OPENAI_API_BASE_URL=https://ope | ||||
|  | ||||
| 注意,具体的 API Base 的格式取决于你所使用的客户端。 | ||||
|  | ||||
| 例如对于 OpenAI 的官方库: | ||||
| ```bash | ||||
| OPENAI_API_KEY="sk-xxxxxx" | ||||
| OPENAI_API_BASE="https://<HOST>:<PORT>/v1"  | ||||
| ``` | ||||
|  | ||||
| ```mermaid | ||||
| graph LR | ||||
|     A(用户) | ||||
|   | ||||
| @@ -111,7 +111,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He | ||||
| } | ||||
|  | ||||
| func updateChannelCloseAIBalance(channel *model.Channel) (float64, error) { | ||||
| 	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.BaseURL) | ||||
| 	url := fmt.Sprintf("%s/dashboard/billing/credit_grants", channel.GetBaseURL()) | ||||
| 	body, err := GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) | ||||
|  | ||||
| 	if err != nil { | ||||
| @@ -201,18 +201,18 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { | ||||
|  | ||||
| func updateChannelBalance(channel *model.Channel) (float64, error) { | ||||
| 	baseURL := common.ChannelBaseURLs[channel.Type] | ||||
| 	if channel.BaseURL == "" { | ||||
| 		channel.BaseURL = baseURL | ||||
| 	if channel.GetBaseURL() == "" { | ||||
| 		channel.BaseURL = &baseURL | ||||
| 	} | ||||
| 	switch channel.Type { | ||||
| 	case common.ChannelTypeOpenAI: | ||||
| 		if channel.BaseURL != "" { | ||||
| 			baseURL = channel.BaseURL | ||||
| 		if channel.GetBaseURL() != "" { | ||||
| 			baseURL = channel.GetBaseURL() | ||||
| 		} | ||||
| 	case common.ChannelTypeAzure: | ||||
| 		return 0, errors.New("尚未实现") | ||||
| 	case common.ChannelTypeCustom: | ||||
| 		baseURL = channel.BaseURL | ||||
| 		baseURL = channel.GetBaseURL() | ||||
| 	case common.ChannelTypeCloseAI: | ||||
| 		return updateChannelCloseAIBalance(channel) | ||||
| 	case common.ChannelTypeOpenAISB: | ||||
|   | ||||
| @@ -42,10 +42,10 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai | ||||
| 	} | ||||
| 	requestURL := common.ChannelBaseURLs[channel.Type] | ||||
| 	if channel.Type == common.ChannelTypeAzure { | ||||
| 		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model) | ||||
| 		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.GetBaseURL(), request.Model) | ||||
| 	} else { | ||||
| 		if channel.BaseURL != "" { | ||||
| 			requestURL = channel.BaseURL | ||||
| 		if channel.GetBaseURL() != "" { | ||||
| 			requestURL = channel.GetBaseURL() | ||||
| 		} | ||||
| 		requestURL += "/v1/chat/completions" | ||||
| 	} | ||||
|   | ||||
| @@ -82,9 +82,9 @@ func Distribute() func(c *gin.Context) { | ||||
| 		c.Set("channel", channel.Type) | ||||
| 		c.Set("channel_id", channel.Id) | ||||
| 		c.Set("channel_name", channel.Name) | ||||
| 		c.Set("model_mapping", channel.ModelMapping) | ||||
| 		c.Set("model_mapping", channel.GetModelMapping()) | ||||
| 		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) | ||||
| 		c.Set("base_url", channel.BaseURL) | ||||
| 		c.Set("base_url", channel.GetBaseURL()) | ||||
| 		switch channel.Type { | ||||
| 		case common.ChannelTypeAzure: | ||||
| 			c.Set("api_version", channel.Other) | ||||
|   | ||||
| @@ -10,16 +10,18 @@ type Ability struct { | ||||
| 	Model     string `json:"model" gorm:"primaryKey;autoIncrement:false"` | ||||
| 	ChannelId int    `json:"channel_id" gorm:"primaryKey;autoIncrement:false;index"` | ||||
| 	Enabled   bool   `json:"enabled"` | ||||
| 	Priority  int64  `json:"priority" gorm:"bigint;default:0"` | ||||
| 	Priority  *int64 `json:"priority" gorm:"bigint;default:0;index"` | ||||
| } | ||||
|  | ||||
| func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) { | ||||
| 	ability := Ability{} | ||||
| 	var err error = nil | ||||
| 	maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where("`group` = ? and model = ? and enabled = 1", group, model) | ||||
| 	channelQuery := DB.Where("`group` = ? and model = ? and enabled = 1 and priority = (?)", group, model, maxPrioritySubQuery) | ||||
| 	if common.UsingSQLite { | ||||
| 		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RANDOM() END DESC ").Limit(1).First(&ability).Error | ||||
| 		err = channelQuery.Order("RANDOM()").Limit(1).First(&ability).Error | ||||
| 	} else { | ||||
| 		err = DB.Where("`group` = ? and model = ? and enabled = 1", group, model).Order("CASE WHEN priority <> 0 THEN priority ELSE RAND() END DESC").Limit(1).First(&ability).Error | ||||
| 		err = channelQuery.Order("RAND()").Limit(1).First(&ability).Error | ||||
| 	} | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
|   | ||||
| @@ -165,7 +165,7 @@ func InitChannelCache() { | ||||
| 	for group, model2channels := range newGroup2model2channels { | ||||
| 		for model, channels := range model2channels { | ||||
| 			sort.Slice(channels, func(i, j int) bool { | ||||
| 				return channels[i].Priority > channels[j].Priority | ||||
| 				return channels[i].GetPriority() > channels[j].GetPriority() | ||||
| 			}) | ||||
| 			newGroup2model2channels[group][model] = channels | ||||
| 		} | ||||
| @@ -195,11 +195,17 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error | ||||
| 	if len(channels) == 0 { | ||||
| 		return nil, errors.New("channel not found") | ||||
| 	} | ||||
| 	endIdx := len(channels) | ||||
| 	// choose by priority | ||||
| 	firstChannel := channels[0] | ||||
| 	if firstChannel.Priority > 0 { | ||||
| 		return firstChannel, nil | ||||
| 	if firstChannel.GetPriority() > 0 { | ||||
| 		for i := range channels { | ||||
| 			if channels[i].GetPriority() != firstChannel.GetPriority() { | ||||
| 				endIdx = i | ||||
| 				break | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	idx := rand.Intn(len(channels)) | ||||
| 	idx := rand.Intn(endIdx) | ||||
| 	return channels[idx], nil | ||||
| } | ||||
|   | ||||
| @@ -15,15 +15,15 @@ type Channel struct { | ||||
| 	CreatedTime        int64   `json:"created_time" gorm:"bigint"` | ||||
| 	TestTime           int64   `json:"test_time" gorm:"bigint"` | ||||
| 	ResponseTime       int     `json:"response_time"` // in milliseconds | ||||
| 	BaseURL            string  `json:"base_url" gorm:"column:base_url"` | ||||
| 	BaseURL            *string `json:"base_url" gorm:"column:base_url;default:''"` | ||||
| 	Other              string  `json:"other"` | ||||
| 	Balance            float64 `json:"balance"` // in USD | ||||
| 	BalanceUpdatedTime int64   `json:"balance_updated_time" gorm:"bigint"` | ||||
| 	Models             string  `json:"models"` | ||||
| 	Group              string  `json:"group" gorm:"type:varchar(32);default:'default'"` | ||||
| 	UsedQuota          int64   `json:"used_quota" gorm:"bigint;default:0"` | ||||
| 	ModelMapping       string  `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | ||||
| 	Priority           int64   `json:"priority" gorm:"bigint;default:0"` | ||||
| 	ModelMapping       *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` | ||||
| 	Priority           *int64  `json:"priority" gorm:"bigint;default:0"` | ||||
| } | ||||
|  | ||||
| func GetAllChannels(startIdx int, num int, selectAll bool) ([]*Channel, error) { | ||||
| @@ -79,6 +79,27 @@ func BatchInsertChannels(channels []Channel) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (channel *Channel) GetPriority() int64 { | ||||
| 	if channel.Priority == nil { | ||||
| 		return 0 | ||||
| 	} | ||||
| 	return *channel.Priority | ||||
| } | ||||
|  | ||||
| func (channel *Channel) GetBaseURL() string { | ||||
| 	if channel.BaseURL == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return *channel.BaseURL | ||||
| } | ||||
|  | ||||
| func (channel *Channel) GetModelMapping() string { | ||||
| 	if channel.ModelMapping == nil { | ||||
| 		return "" | ||||
| 	} | ||||
| 	return *channel.ModelMapping | ||||
| } | ||||
|  | ||||
| func (channel *Channel) Insert() error { | ||||
| 	var err error | ||||
| 	err = DB.Create(channel).Error | ||||
|   | ||||
| @@ -174,7 +174,7 @@ const EditChannel = () => { | ||||
|       return; | ||||
|     } | ||||
|     let localInputs = inputs; | ||||
|     if (localInputs.base_url.endsWith('/')) { | ||||
|     if (localInputs.base_url && localInputs.base_url.endsWith('/')) { | ||||
|       localInputs.base_url = localInputs.base_url.slice(0, localInputs.base_url.length - 1); | ||||
|     } | ||||
|     if (localInputs.type === 3 && localInputs.other === '') { | ||||
| @@ -183,9 +183,6 @@ const EditChannel = () => { | ||||
|     if (localInputs.type === 18 && localInputs.other === '') { | ||||
|       localInputs.other = 'v2.1'; | ||||
|     } | ||||
|     if (localInputs.model_mapping === '') { | ||||
|       localInputs.model_mapping = '{}'; | ||||
|     } | ||||
|     let res; | ||||
|     localInputs.models = localInputs.models.join(','); | ||||
|     localInputs.group = localInputs.groups.join(','); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user