// Mirror of https://github.com/songquanpeng/one-api.git
// (synced 2025-09-19 01:56:37 +08:00; 145 lines, 3.9 KiB, Go)

package imagen
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/pkg/errors"
|
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
|
"github.com/songquanpeng/one-api/relay/meta"
|
|
"github.com/songquanpeng/one-api/relay/model"
|
|
"github.com/songquanpeng/one-api/relay/relaymode"
|
|
)
|
|
|
|
// ModelList holds the Vertex AI Imagen model identifiers handled by this
// adaptor, returned as-is by GetModelList.
var ModelList = []string{
	// Text-to-image generation models.
	"imagen-3.0-generate-001",
	"imagen-3.0-generate-002",
	"imagen-3.0-fast-generate-001",

	// Image editing ("capability") model.
	"imagen-3.0-capability-001",
}
|
|
|
|
// Adaptor relays OpenAI-style image requests to Vertex AI Imagen.
// It carries no fields, so its zero value is ready to use.
type Adaptor struct{}
|
|
|
|
func (a *Adaptor) Init(meta *meta.Meta) {
|
|
// No initialization needed
|
|
}
|
|
|
|
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
|
return nil, errors.New("not implemented")
|
|
}
|
|
|
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
|
meta := meta.GetByContext(c)
|
|
|
|
if request.ResponseFormat != "b64_json" {
|
|
return nil, errors.New("only support b64_json response format")
|
|
}
|
|
if request.N <= 0 {
|
|
request.N = 1 // Default to 1 if not specified
|
|
}
|
|
|
|
switch meta.Mode {
|
|
case relaymode.ImagesGenerations:
|
|
return convertImageCreateRequest(request)
|
|
case relaymode.ImagesEdits:
|
|
switch c.ContentType() {
|
|
// case "application/json":
|
|
// return ConvertJsonImageEditRequest(c)
|
|
case "multipart/form-data":
|
|
return ConvertMultipartImageEditRequest(c)
|
|
default:
|
|
return nil, errors.New("unsupported content type for image edit")
|
|
}
|
|
default:
|
|
return nil, errors.New("not implemented")
|
|
}
|
|
}
|
|
|
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
|
|
respBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return nil, openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
|
}
|
|
resp.Body.Close()
|
|
resp.Body = io.NopCloser(bytes.NewBuffer(respBody))
|
|
|
|
switch meta.Mode {
|
|
case relaymode.ImagesEdits:
|
|
return HandleImageEdit(c, resp)
|
|
case relaymode.ImagesGenerations:
|
|
return nil, handleImageGeneration(c, resp, respBody)
|
|
default:
|
|
return nil, openai.ErrorWrapper(errors.New("unsupported mode"), "unsupported_mode", http.StatusBadRequest)
|
|
}
|
|
}
|
|
|
|
func handleImageGeneration(c *gin.Context, resp *http.Response, respBody []byte) *model.ErrorWithStatusCode {
|
|
var imageResponse CreateImageResponse
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return openai.ErrorWrapper(errors.New(string(respBody)), "imagen_api_error", resp.StatusCode)
|
|
}
|
|
|
|
err := json.Unmarshal(respBody, &imageResponse)
|
|
if err != nil {
|
|
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
|
}
|
|
|
|
// Convert to OpenAI format
|
|
openaiResp := openai.ImageResponse{
|
|
Created: time.Now().Unix(),
|
|
Data: make([]openai.ImageData, 0, len(imageResponse.Predictions)),
|
|
}
|
|
|
|
for _, prediction := range imageResponse.Predictions {
|
|
openaiResp.Data = append(openaiResp.Data, openai.ImageData{
|
|
B64Json: prediction.BytesBase64Encoded,
|
|
})
|
|
}
|
|
|
|
respBytes, err := json.Marshal(openaiResp)
|
|
if err != nil {
|
|
return openai.ErrorWrapper(err, "marshal_response_failed", http.StatusInternalServerError)
|
|
}
|
|
|
|
c.Writer.Header().Set("Content-Type", "application/json")
|
|
c.Writer.WriteHeader(http.StatusOK)
|
|
_, err = c.Writer.Write(respBytes)
|
|
if err != nil {
|
|
return openai.ErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *Adaptor) GetModelList() []string {
|
|
return ModelList
|
|
}
|
|
|
|
func (a *Adaptor) GetChannelName() string {
|
|
return "vertex_ai_imagen"
|
|
}
|
|
|
|
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
|
|
return CreateImageRequest{
|
|
Instances: []createImageInstance{
|
|
{
|
|
Prompt: request.Prompt,
|
|
},
|
|
},
|
|
Parameters: createImageParameters{
|
|
SampleCount: request.N,
|
|
},
|
|
}, nil
|
|
}
|