mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	feat: support cogview-3
This commit is contained in:
		@@ -62,8 +62,8 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"text-search-ada-doc-001": 10,
 | 
			
		||||
	"text-moderation-stable":  0.1,
 | 
			
		||||
	"text-moderation-latest":  0.1,
 | 
			
		||||
	"dall-e-2":                8,  // $0.016 - $0.020 / image
 | 
			
		||||
	"dall-e-3":                20, // $0.040 - $0.120 / image
 | 
			
		||||
	"dall-e-2":                0.02 * USD, // $0.016 - $0.020 / image
 | 
			
		||||
	"dall-e-3":                0.04 * USD, // $0.040 - $0.120 / image
 | 
			
		||||
	// https://www.anthropic.com/api#pricing
 | 
			
		||||
	"claude-instant-1.2":       0.8 / 1000 * USD,
 | 
			
		||||
	"claude-2.0":               8.0 / 1000 * USD,
 | 
			
		||||
@@ -96,14 +96,15 @@ var ModelRatio = map[string]float64{
 | 
			
		||||
	"gemini-1.0-pro-001":        1,
 | 
			
		||||
	"gemini-1.5-pro":            1,
 | 
			
		||||
	// https://open.bigmodel.cn/pricing
 | 
			
		||||
	"glm-4":                     0.1 * RMB,
 | 
			
		||||
	"glm-4v":                    0.1 * RMB,
 | 
			
		||||
	"glm-3-turbo":               0.005 * RMB,
 | 
			
		||||
	"embedding-2":               0.0005 * RMB,
 | 
			
		||||
	"chatglm_turbo":             0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
	"chatglm_pro":               0.7143, // ¥0.01 / 1k tokens
 | 
			
		||||
	"chatglm_std":               0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
	"chatglm_lite":              0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"glm-4":         0.1 * RMB,
 | 
			
		||||
	"glm-4v":        0.1 * RMB,
 | 
			
		||||
	"glm-3-turbo":   0.005 * RMB,
 | 
			
		||||
	"embedding-2":   0.0005 * RMB,
 | 
			
		||||
	"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
	"chatglm_pro":   0.7143, // ¥0.01 / 1k tokens
 | 
			
		||||
	"chatglm_std":   0.3572, // ¥0.005 / 1k tokens
 | 
			
		||||
	"chatglm_lite":  0.1429, // ¥0.002 / 1k tokens
 | 
			
		||||
	"cogview-3":     0.25 * RMB,
 | 
			
		||||
	// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
 | 
			
		||||
	"qwen-turbo":                0.5715, // ¥0.008 / 1k tokens
 | 
			
		||||
	"qwen-plus":                 1.4286, // ¥0.02 / 1k tokens
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										44
									
								
								relay/channel/openai/image.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								relay/channel/openai/image.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,44 @@
 | 
			
		||||
package openai
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/model"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
 | 
			
		||||
	var imageResponse ImageResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &imageResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
 | 
			
		||||
	for k, v := range resp.Header {
 | 
			
		||||
		c.Writer.Header().Set(k, v[0])
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
 | 
			
		||||
	_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
@@ -149,37 +149,3 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
 | 
			
		||||
	}
 | 
			
		||||
	return nil, &textResponse.Usage
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ImageHandler relays an upstream image-generation response to the client.
// It reads the entire body, verifies it unmarshals as an ImageResponse (the
// decoded value is otherwise unused), then replays the upstream status code,
// headers and body verbatim. The *model.Usage result is always nil.
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
	var imageResponse ImageResponse
	responseBody, err := io.ReadAll(resp.Body)

	if err != nil {
		return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	// Validates that the upstream payload is well-formed JSON.
	err = json.Unmarshal(responseBody, &imageResponse)
	if err != nil {
		return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}

	// The body was consumed above; rewind it via an in-memory reader so the
	// io.Copy below can stream it to the client.
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))

	// NOTE(review): only the first value of each header is forwarded, so
	// multi-valued headers (e.g. Set-Cookie) lose their extra values.
	for k, v := range resp.Header {
		c.Writer.Header().Set(k, v[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)

	_, err = io.Copy(c.Writer, resp.Body)
	if err != nil {
		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	return nil, nil
}
 | 
			
		||||
 
 | 
			
		||||
@@ -32,13 +32,16 @@ func (a *Adaptor) SetVersionByModeName(modelName string) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
 | 
			
		||||
	switch meta.Mode {
 | 
			
		||||
	case constant.RelayModeImagesGenerations:
 | 
			
		||||
		return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil
 | 
			
		||||
	case constant.RelayModeEmbeddings:
 | 
			
		||||
		return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
 | 
			
		||||
	}
 | 
			
		||||
	a.SetVersionByModeName(meta.ActualModelName)
 | 
			
		||||
	if a.APIVersion == "v4" {
 | 
			
		||||
		return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
 | 
			
		||||
	}
 | 
			
		||||
	if meta.Mode == constant.RelayModeEmbeddings {
 | 
			
		||||
		return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
 | 
			
		||||
	}
 | 
			
		||||
	method := "invoke"
 | 
			
		||||
	if meta.IsStream {
 | 
			
		||||
		method = "sse-invoke"
 | 
			
		||||
@@ -81,7 +84,12 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
 | 
			
		||||
	if request == nil {
 | 
			
		||||
		return nil, errors.New("request is nil")
 | 
			
		||||
	}
 | 
			
		||||
	return request, nil
 | 
			
		||||
	newRequest := ImageRequest{
 | 
			
		||||
		Model:  request.Model,
 | 
			
		||||
		Prompt: request.Prompt,
 | 
			
		||||
		UserId: request.User,
 | 
			
		||||
	}
 | 
			
		||||
	return newRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
 | 
			
		||||
@@ -98,10 +106,17 @@ func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.R
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
 | 
			
		||||
	switch meta.Mode {
 | 
			
		||||
	case constant.RelayModeEmbeddings:
 | 
			
		||||
		err, usage = EmbeddingsHandler(c, resp)
 | 
			
		||||
		return
 | 
			
		||||
	case constant.RelayModeImagesGenerations:
 | 
			
		||||
		err, usage = openai.ImageHandler(c, resp)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if a.APIVersion == "v4" {
 | 
			
		||||
		return a.DoResponseV4(c, resp, meta)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if meta.IsStream {
 | 
			
		||||
		err, usage = StreamHandler(c, resp)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -3,4 +3,5 @@ package zhipu
 | 
			
		||||
// ModelList enumerates the zhipu models this adaptor supports.
var ModelList = []string{
	"chatglm_turbo",
	"chatglm_pro",
	"chatglm_std",
	"chatglm_lite",
	"glm-4",
	"glm-4v",
	"glm-3-turbo",
	"embedding-2",
	"cogview-3",
}
 | 
			
		||||
 
 | 
			
		||||
@@ -62,3 +62,9 @@ type EmbeddingData struct {
 | 
			
		||||
	Object    string    `json:"object"`
 | 
			
		||||
	Embedding []float64 `json:"embedding"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ImageRequest is the request body sent to the image-generation endpoint
// (used for cogview-3). UserId is serialized as "user_id" and omitted
// from the JSON when empty.
type ImageRequest struct {
	Model  string `json:"model"`
	Prompt string `json:"prompt"`
	UserId string `json:"user_id,omitempty"`
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,6 @@
 | 
			
		||||
package constant
 | 
			
		||||
 | 
			
		||||
var DalleSizeRatios = map[string]map[string]float64{
 | 
			
		||||
var ImageSizeRatios = map[string]map[string]float64{
 | 
			
		||||
	"dall-e-2": {
 | 
			
		||||
		"256x256":   1,
 | 
			
		||||
		"512x512":   1.125,
 | 
			
		||||
@@ -11,7 +11,14 @@ var DalleSizeRatios = map[string]map[string]float64{
 | 
			
		||||
		"1024x1792": 2,
 | 
			
		||||
		"1792x1024": 2,
 | 
			
		||||
	},
 | 
			
		||||
	"stable-diffusion-xl": {
 | 
			
		||||
	"ali-stable-diffusion-xl": {
 | 
			
		||||
		"512x1024":  1,
 | 
			
		||||
		"1024x768":  1,
 | 
			
		||||
		"1024x1024": 1,
 | 
			
		||||
		"576x1024":  1,
 | 
			
		||||
		"1024x576":  1,
 | 
			
		||||
	},
 | 
			
		||||
	"ali-stable-diffusion-v1.5": {
 | 
			
		||||
		"512x1024":  1,
 | 
			
		||||
		"1024x768":  1,
 | 
			
		||||
		"1024x1024": 1,
 | 
			
		||||
@@ -25,17 +32,20 @@ var DalleSizeRatios = map[string]map[string]float64{
 | 
			
		||||
	},
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var DalleGenerationImageAmounts = map[string][2]int{
 | 
			
		||||
	"dall-e-2":            {1, 10},
 | 
			
		||||
	"dall-e-3":            {1, 1}, // OpenAI allows n=1 currently.
 | 
			
		||||
	"stable-diffusion-xl": {1, 4}, // Ali
 | 
			
		||||
	"wanx-v1":             {1, 4}, // Ali
 | 
			
		||||
// ImageGenerationAmounts maps an image model name to the inclusive
// [min, max] range allowed for the number of images per request.
var ImageGenerationAmounts = map[string][2]int{
	"dall-e-2":                  {1, 10},
	"dall-e-3":                  {1, 1}, // OpenAI allows n=1 currently.
	"ali-stable-diffusion-xl":   {1, 4}, // Ali
	"ali-stable-diffusion-v1.5": {1, 4}, // Ali
	"wanx-v1":                   {1, 4}, // Ali
	"cogview-3":                 {1, 1},
}
 | 
			
		||||
 | 
			
		||||
var DalleImagePromptLengthLimitations = map[string]int{
 | 
			
		||||
	"dall-e-2":            1000,
 | 
			
		||||
	"dall-e-3":            4000,
 | 
			
		||||
	"stable-diffusion-xl": 4000,
 | 
			
		||||
	"wanx-v1":             4000,
 | 
			
		||||
	"cogview-3":           833,
 | 
			
		||||
// ImagePromptLengthLimitations maps an image model name to the maximum
// accepted prompt length in bytes (validated with len(prompt)).
var ImagePromptLengthLimitations = map[string]int{
	"dall-e-2":                  1000,
	"dall-e-3":                  4000,
	"ali-stable-diffusion-xl":   4000,
	"ali-stable-diffusion-v1.5": 4000,
	"wanx-v1":                   4000,
	"cogview-3":                 833,
}
 | 
			
		||||
 
 | 
			
		||||
@@ -54,9 +54,25 @@ func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, e
 | 
			
		||||
	return imageRequest, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isValidImageSize(model string, size string) bool {
 | 
			
		||||
	if model == "cogview-3" {
 | 
			
		||||
		return true
 | 
			
		||||
	}
 | 
			
		||||
	_, ok := constant.ImageSizeRatios[model][size]
 | 
			
		||||
	return ok
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getImageSizeRatio(model string, size string) float64 {
 | 
			
		||||
	ratio, ok := constant.ImageSizeRatios[model][size]
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return 1
 | 
			
		||||
	}
 | 
			
		||||
	return ratio
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode {
 | 
			
		||||
	// model validation
 | 
			
		||||
	_, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
 | 
			
		||||
	hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
 | 
			
		||||
	if !hasValidSize {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
@@ -64,7 +80,7 @@ func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.Rela
 | 
			
		||||
	if imageRequest.Prompt == "" {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] {
 | 
			
		||||
	if len(imageRequest.Prompt) > constant.ImagePromptLengthLimitations[imageRequest.Model] {
 | 
			
		||||
		return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
 | 
			
		||||
	}
 | 
			
		||||
	// Number of generated images validation
 | 
			
		||||
@@ -81,10 +97,7 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
 | 
			
		||||
	if imageRequest == nil {
 | 
			
		||||
		return 0, errors.New("imageRequest is nil")
 | 
			
		||||
	}
 | 
			
		||||
	imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
 | 
			
		||||
	if !hasValidSize {
 | 
			
		||||
		return 0, fmt.Errorf("size not supported for this image model: %s", imageRequest.Size)
 | 
			
		||||
	}
 | 
			
		||||
	imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
 | 
			
		||||
	if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
 | 
			
		||||
		if imageRequest.Size == "1024x1024" {
 | 
			
		||||
			imageCostRatio *= 2
 | 
			
		||||
 
 | 
			
		||||
@@ -20,12 +20,11 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func isWithinRange(element string, value int) bool {
 | 
			
		||||
	if _, ok := constant.DalleGenerationImageAmounts[element]; !ok {
 | 
			
		||||
	if _, ok := constant.ImageGenerationAmounts[element]; !ok {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	min := constant.DalleGenerationImageAmounts[element][0]
 | 
			
		||||
	max := constant.DalleGenerationImageAmounts[element][1]
 | 
			
		||||
 | 
			
		||||
	min := constant.ImageGenerationAmounts[element][0]
 | 
			
		||||
	max := constant.ImageGenerationAmounts[element][1]
 | 
			
		||||
	return value >= min && value <= max
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -81,7 +80,6 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		jsonStr, err := json.Marshal(finalRequest)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user