mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	
							
								
								
									
										30
									
								
								common/group-ratio.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								common/group-ratio.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,30 @@
 | 
			
		||||
package common
 | 
			
		||||
 | 
			
		||||
import "encoding/json"
 | 
			
		||||
 | 
			
		||||
var GroupRatio = map[string]float64{
 | 
			
		||||
	"default": 1,
 | 
			
		||||
	"vip":     1,
 | 
			
		||||
	"svip":    1,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GroupRatio2JSONString() string {
 | 
			
		||||
	jsonBytes, err := json.Marshal(GroupRatio)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		SysError("Error marshalling model ratio: " + err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	return string(jsonBytes)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func UpdateGroupRatioByJSONString(jsonStr string) error {
 | 
			
		||||
	return json.Unmarshal([]byte(jsonStr), &GroupRatio)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetGroupRatio(name string) float64 {
 | 
			
		||||
	ratio, ok := GroupRatio[name]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		SysError("Group ratio not found: " + name)
 | 
			
		||||
		return 1
 | 
			
		||||
	}
 | 
			
		||||
	return ratio
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										19
									
								
								controller/group.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								controller/group.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
			
		||||
package controller
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetGroups(c *gin.Context) {
 | 
			
		||||
	groupNames := make([]string, 0)
 | 
			
		||||
	for groupName, _ := range common.GroupRatio {
 | 
			
		||||
		groupNames = append(groupNames, groupName)
 | 
			
		||||
	}
 | 
			
		||||
	c.JSON(http.StatusOK, gin.H{
 | 
			
		||||
		"success": true,
 | 
			
		||||
		"message": "",
 | 
			
		||||
		"data":    groupNames,
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
@@ -140,6 +140,7 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	channelType := c.GetInt("channel")
 | 
			
		||||
	tokenId := c.GetInt("token_id")
 | 
			
		||||
	consumeQuota := c.GetBool("consume_quota")
 | 
			
		||||
	group := c.GetString("group")
 | 
			
		||||
	var textRequest GeneralOpenAIRequest
 | 
			
		||||
	if consumeQuota || channelType == common.ChannelTypeAzure || channelType == common.ChannelTypePaLM {
 | 
			
		||||
		err := common.UnmarshalBodyReusable(c, &textRequest)
 | 
			
		||||
@@ -194,7 +195,7 @@ func relayHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 | 
			
		||||
	if textRequest.MaxTokens != 0 {
 | 
			
		||||
		preConsumedTokens = promptTokens + textRequest.MaxTokens
 | 
			
		||||
	}
 | 
			
		||||
	ratio := common.GetModelRatio(textRequest.Model)
 | 
			
		||||
	ratio := common.GetModelRatio(textRequest.Model) * common.GetGroupRatio(group)
 | 
			
		||||
	preConsumedQuota := int(float64(preConsumedTokens) * ratio)
 | 
			
		||||
	if consumeQuota {
 | 
			
		||||
		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
 | 
			
		||||
 
 | 
			
		||||
@@ -16,6 +16,9 @@ type ModelRequest struct {
 | 
			
		||||
 | 
			
		||||
func Distribute() func(c *gin.Context) {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		userId := c.GetInt("id")
 | 
			
		||||
		userGroup, _ := model.GetUserGroup(userId)
 | 
			
		||||
		c.Set("group", userGroup)
 | 
			
		||||
		var channel *model.Channel
 | 
			
		||||
		channelId, ok := c.Get("channelId")
 | 
			
		||||
		if ok {
 | 
			
		||||
@@ -70,8 +73,6 @@ func Distribute() func(c *gin.Context) {
 | 
			
		||||
					modelRequest.Model = "text-moderation-stable"
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			userId := c.GetInt("id")
 | 
			
		||||
			userGroup, _ := model.GetUserGroup(userId)
 | 
			
		||||
			channel, err = model.GetRandomSatisfiedChannel(userGroup, modelRequest.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.JSON(200, gin.H{
 | 
			
		||||
 
 | 
			
		||||
@@ -58,6 +58,7 @@ func InitOptionMap() {
 | 
			
		||||
	common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold)
 | 
			
		||||
	common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota)
 | 
			
		||||
	common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
 | 
			
		||||
	common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
 | 
			
		||||
	common.OptionMap["TopUpLink"] = common.TopUpLink
 | 
			
		||||
	common.OptionMapRWMutex.Unlock()
 | 
			
		||||
	loadOptionsFromDatabase()
 | 
			
		||||
@@ -177,6 +178,8 @@ func updateOptionMap(key string, value string) (err error) {
 | 
			
		||||
		common.PreConsumedQuota, _ = strconv.Atoi(value)
 | 
			
		||||
	case "ModelRatio":
 | 
			
		||||
		err = common.UpdateModelRatioByJSONString(value)
 | 
			
		||||
	case "GroupRatio":
 | 
			
		||||
		err = common.UpdateGroupRatioByJSONString(value)
 | 
			
		||||
	case "TopUpLink":
 | 
			
		||||
		common.TopUpLink = value
 | 
			
		||||
	case "ChannelDisableThreshold":
 | 
			
		||||
 
 | 
			
		||||
@@ -98,5 +98,10 @@ func SetApiRouter(router *gin.Engine) {
 | 
			
		||||
		logRoute.GET("/search", middleware.AdminAuth(), controller.SearchAllLogs)
 | 
			
		||||
		logRoute.GET("/self", middleware.UserAuth(), controller.GetUserLogs)
 | 
			
		||||
		logRoute.GET("/self/search", middleware.UserAuth(), controller.SearchUserLogs)
 | 
			
		||||
		groupRoute := apiRouter.Group("/group")
 | 
			
		||||
		groupRoute.Use(middleware.AdminAuth())
 | 
			
		||||
		{
 | 
			
		||||
			groupRoute.GET("/", controller.GetGroups)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -30,6 +30,7 @@ const SystemSetting = () => {
 | 
			
		||||
    QuotaRemindThreshold: 0,
 | 
			
		||||
    PreConsumedQuota: 0,
 | 
			
		||||
    ModelRatio: '',
 | 
			
		||||
    GroupRatio: '',
 | 
			
		||||
    TopUpLink: '',
 | 
			
		||||
    AutomaticDisableChannelEnabled: '',
 | 
			
		||||
    ChannelDisableThreshold: 0,
 | 
			
		||||
@@ -101,6 +102,7 @@ const SystemSetting = () => {
 | 
			
		||||
      name === 'QuotaRemindThreshold' ||
 | 
			
		||||
      name === 'PreConsumedQuota' ||
 | 
			
		||||
      name === 'ModelRatio' ||
 | 
			
		||||
      name === 'GroupRatio' ||
 | 
			
		||||
      name === 'TopUpLink'
 | 
			
		||||
    ) {
 | 
			
		||||
      setInputs((inputs) => ({ ...inputs, [name]: value }));
 | 
			
		||||
@@ -131,6 +133,13 @@ const SystemSetting = () => {
 | 
			
		||||
      }
 | 
			
		||||
      await updateOption('ModelRatio', inputs.ModelRatio);
 | 
			
		||||
    }
 | 
			
		||||
    if (originInputs['GroupRatio'] !== inputs.GroupRatio) {
 | 
			
		||||
      if (!verifyJSON(inputs.GroupRatio)) {
 | 
			
		||||
        showError('分组倍率不是合法的 JSON 字符串');
 | 
			
		||||
        return;
 | 
			
		||||
      }
 | 
			
		||||
      await updateOption('GroupRatio', inputs.GroupRatio);
 | 
			
		||||
    }
 | 
			
		||||
    if (originInputs['TopUpLink'] !== inputs.TopUpLink) {
 | 
			
		||||
      await updateOption('TopUpLink', inputs.TopUpLink);
 | 
			
		||||
    }
 | 
			
		||||
@@ -329,6 +338,17 @@ const SystemSetting = () => {
 | 
			
		||||
              placeholder='为一个 JSON 文本,键为模型名称,值为倍率'
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Group widths='equal'>
 | 
			
		||||
            <Form.TextArea
 | 
			
		||||
              label='分组倍率'
 | 
			
		||||
              name='GroupRatio'
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              style={{ minHeight: 250, fontFamily: 'JetBrains Mono, Consolas' }}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
              value={inputs.GroupRatio}
 | 
			
		||||
              placeholder='为一个 JSON 文本,键为分组名称,值为倍率'
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Group>
 | 
			
		||||
          <Form.Button onClick={submitOperationConfig}>保存运营设置</Form.Button>
 | 
			
		||||
          <Divider />
 | 
			
		||||
          <Header as='h3'>
 | 
			
		||||
 
 | 
			
		||||
@@ -10,6 +10,10 @@ export function renderText(text, limit) {
 | 
			
		||||
export function renderGroup(group) {
 | 
			
		||||
  if (group === "") {
 | 
			
		||||
    return <Label>default</Label>
 | 
			
		||||
  } else if (group === "vip" || group === "pro") {
 | 
			
		||||
    return <Label color='yellow'>{group}</Label>
 | 
			
		||||
  } else if (group === "svip" || group === "premium") {
 | 
			
		||||
    return <Label color='red'>{group}</Label>
 | 
			
		||||
  }
 | 
			
		||||
  return <Label>{group}</Label>
 | 
			
		||||
}
 | 
			
		||||
@@ -21,6 +21,7 @@ const EditChannel = () => {
 | 
			
		||||
  const [batch, setBatch] = useState(false);
 | 
			
		||||
  const [inputs, setInputs] = useState(originInputs);
 | 
			
		||||
  const [modelOptions, setModelOptions] = useState([]);
 | 
			
		||||
  const [groupOptions, setGroupOptions] = useState([]);
 | 
			
		||||
  const [basicModels, setBasicModels] = useState([]);
 | 
			
		||||
  const [fullModels, setFullModels] = useState([]);
 | 
			
		||||
  const handleInputChange = (e, { name, value }) => {
 | 
			
		||||
@@ -58,11 +59,25 @@ const EditChannel = () => {
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const fetchGroups = async () => {
 | 
			
		||||
    try {
 | 
			
		||||
      let res = await API.get(`/api/group`);
 | 
			
		||||
      setGroupOptions(res.data.data.map((group) => ({
 | 
			
		||||
        key: group,
 | 
			
		||||
        text: group,
 | 
			
		||||
        value: group,
 | 
			
		||||
      })));
 | 
			
		||||
    } catch (error) {
 | 
			
		||||
      showError(error.message);
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
    if (isEdit) {
 | 
			
		||||
      loadChannel().then();
 | 
			
		||||
    }
 | 
			
		||||
    fetchModels().then();
 | 
			
		||||
    fetchGroups().then();
 | 
			
		||||
  }, []);
 | 
			
		||||
 | 
			
		||||
  const submit = async () => {
 | 
			
		||||
@@ -167,13 +182,19 @@ const EditChannel = () => {
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
            <Form.Input
 | 
			
		||||
            <Form.Dropdown
 | 
			
		||||
              label='分组'
 | 
			
		||||
              placeholder={'请选择分组'}
 | 
			
		||||
              name='group'
 | 
			
		||||
              placeholder={'请输入分组'}
 | 
			
		||||
              fluid
 | 
			
		||||
              search
 | 
			
		||||
              selection
 | 
			
		||||
              allowAdditions
 | 
			
		||||
              additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              value={inputs.group}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
              options={groupOptions}
 | 
			
		||||
            />
 | 
			
		||||
          </Form.Field>
 | 
			
		||||
          <Form.Field>
 | 
			
		||||
 
 | 
			
		||||
@@ -17,11 +17,24 @@ const EditUser = () => {
 | 
			
		||||
    quota: 0,
 | 
			
		||||
    group: 'default'
 | 
			
		||||
  });
 | 
			
		||||
  const [groupOptions, setGroupOptions] = useState([]);
 | 
			
		||||
  const { username, display_name, password, github_id, wechat_id, email, quota, group } =
 | 
			
		||||
    inputs;
 | 
			
		||||
  const handleInputChange = (e, { name, value }) => {
 | 
			
		||||
    setInputs((inputs) => ({ ...inputs, [name]: value }));
 | 
			
		||||
  };
 | 
			
		||||
  const fetchGroups = async () => {
 | 
			
		||||
    try {
 | 
			
		||||
      let res = await API.get(`/api/group`);
 | 
			
		||||
      setGroupOptions(res.data.data.map((group) => ({
 | 
			
		||||
        key: group,
 | 
			
		||||
        text: group,
 | 
			
		||||
        value: group,
 | 
			
		||||
      })));
 | 
			
		||||
    } catch (error) {
 | 
			
		||||
      showError(error.message);
 | 
			
		||||
    }
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  const loadUser = async () => {
 | 
			
		||||
    let res = undefined;
 | 
			
		||||
@@ -41,6 +54,9 @@ const EditUser = () => {
 | 
			
		||||
  };
 | 
			
		||||
  useEffect(() => {
 | 
			
		||||
    loadUser().then();
 | 
			
		||||
    if (userId) {
 | 
			
		||||
      fetchGroups().then();
 | 
			
		||||
    }
 | 
			
		||||
  }, []);
 | 
			
		||||
 | 
			
		||||
  const submit = async () => {
 | 
			
		||||
@@ -101,13 +117,19 @@ const EditUser = () => {
 | 
			
		||||
          {
 | 
			
		||||
            userId && <>
 | 
			
		||||
              <Form.Field>
 | 
			
		||||
                <Form.Input
 | 
			
		||||
                <Form.Dropdown
 | 
			
		||||
                  label='分组'
 | 
			
		||||
                  placeholder={'请选择分组'}
 | 
			
		||||
                  name='group'
 | 
			
		||||
                  placeholder={'请输入用户分组'}
 | 
			
		||||
                  fluid
 | 
			
		||||
                  search
 | 
			
		||||
                  selection
 | 
			
		||||
                  allowAdditions
 | 
			
		||||
                  additionLabel={'请在系统设置页面编辑分组倍率以添加新的分组:'}
 | 
			
		||||
                  onChange={handleInputChange}
 | 
			
		||||
                  value={group}
 | 
			
		||||
                  value={inputs.group}
 | 
			
		||||
                  autoComplete='new-password'
 | 
			
		||||
                  options={groupOptions}
 | 
			
		||||
                />
 | 
			
		||||
              </Form.Field>
 | 
			
		||||
              <Form.Field>
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user