mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-11 19:03:43 +08:00
feat: support vertex imagen3
This commit is contained in:
107
relay/adaptor/vertexai/imagen/adaptor.go
Normal file
107
relay/adaptor/vertexai/imagen/adaptor.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package imagen
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
)
|
||||
|
||||
var ModelList = []string{
|
||||
"imagen-3.0-generate-001",
|
||||
}
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
meta := meta.GetByContext(c)
|
||||
|
||||
if request.ResponseFormat != "b64_json" {
|
||||
return nil, errors.New("only support b64_json response format")
|
||||
}
|
||||
if request.N <= 0 {
|
||||
return nil, errors.New("n must be greater than 0")
|
||||
}
|
||||
|
||||
switch meta.Mode {
|
||||
case relaymode.ImagesGenerations:
|
||||
return convertImageCreateRequest(request)
|
||||
case relaymode.ImagesEdits:
|
||||
return nil, errors.New("not implemented")
|
||||
default:
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
}
|
||||
|
||||
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
|
||||
return CreateImageRequest{
|
||||
Instances: []createImageInstance{
|
||||
{
|
||||
Prompt: request.Prompt,
|
||||
},
|
||||
},
|
||||
Parameters: createImageParameters{
|
||||
SampleCount: request.N,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, openai.ErrorWrapper(
|
||||
errors.Wrap(err, "failed to read response body"),
|
||||
"read_response_body",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, openai.ErrorWrapper(
|
||||
errors.Errorf("upstream response status code: %d, body: %s", resp.StatusCode, string(respBody)),
|
||||
"upstream_response",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
imagenResp := new(CreateImageResponse)
|
||||
if err := json.Unmarshal(respBody, imagenResp); err != nil {
|
||||
return nil, openai.ErrorWrapper(
|
||||
errors.Wrap(err, "failed to decode response body"),
|
||||
"unmarshal_upstream_response",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
if len(imagenResp.Predictions) == 0 {
|
||||
return nil, openai.ErrorWrapper(
|
||||
errors.New("empty predictions"),
|
||||
"empty_predictions",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
oaiResp := openai.ImageResponse{
|
||||
Created: time.Now().Unix(),
|
||||
}
|
||||
for _, prediction := range imagenResp.Predictions {
|
||||
oaiResp.Data = append(oaiResp.Data, openai.ImageData{
|
||||
B64Json: prediction.BytesBase64Encoded,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, oaiResp)
|
||||
return nil, nil
|
||||
}
|
||||
23
relay/adaptor/vertexai/imagen/model.go
Normal file
23
relay/adaptor/vertexai/imagen/model.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package imagen
|
||||
|
||||
type CreateImageRequest struct {
|
||||
Instances []createImageInstance `json:"instances" binding:"required,min=1"`
|
||||
Parameters createImageParameters `json:"parameters" binding:"required"`
|
||||
}
|
||||
|
||||
type createImageInstance struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type createImageParameters struct {
|
||||
SampleCount int `json:"sample_count" binding:"required,min=1"`
|
||||
}
|
||||
|
||||
type CreateImageResponse struct {
|
||||
Predictions []createImageResponsePrediction `json:"predictions"`
|
||||
}
|
||||
|
||||
type createImageResponsePrediction struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
}
|
||||
Reference in New Issue
Block a user