feat: support cogview-3

This commit is contained in:
JustSong
2024-04-06 00:30:08 +08:00
parent 0cb224e62e
commit 880e12c855
9 changed files with 127 additions and 73 deletions

View File

@@ -0,0 +1,44 @@
package openai
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var imageResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &imageResponse)
if err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, nil
}

View File

@@ -149,37 +149,3 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
}
return nil, &textResponse.Usage
}
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var imageResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &imageResponse)
if err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, nil
}

View File

@@ -32,13 +32,16 @@ func (a *Adaptor) SetVersionByModeName(modelName string) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
switch meta.Mode {
case constant.RelayModeImagesGenerations:
return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
}
a.SetVersionByModeName(meta.ActualModelName)
if a.APIVersion == "v4" {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
}
if meta.Mode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
}
method := "invoke"
if meta.IsStream {
method = "sse-invoke"
@@ -81,7 +84,12 @@ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error)
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
newRequest := ImageRequest{
Model: request.Model,
Prompt: request.Prompt,
UserId: request.User,
}
return newRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
@@ -98,10 +106,17 @@ func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.R
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
switch meta.Mode {
case constant.RelayModeEmbeddings:
err, usage = EmbeddingsHandler(c, resp)
return
case constant.RelayModeImagesGenerations:
err, usage = openai.ImageHandler(c, resp)
return
}
if a.APIVersion == "v4" {
return a.DoResponseV4(c, resp, meta)
}
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {

View File

@@ -3,4 +3,5 @@ package zhipu
var ModelList = []string{
"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
"glm-4", "glm-4v", "glm-3-turbo", "embedding-2",
"cogview-3",
}

View File

@@ -62,3 +62,9 @@ type EmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
}
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
UserId string `json:"user_id,omitempty"`
}