feat: support vertex imagen3

Author: Laisky.Cai
Date: 2025-01-12 04:18:57 +00:00
parent bfe28fc1f8
commit feacea0321
30 changed files with 227 additions and 49 deletions

View File

@@ -1,18 +1,20 @@
package vertexai
import (
"errors"
"fmt"
"io"
"net/http"
"slices"
"strings"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor"
channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/vertexai/imagen"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
relaymodel "github.com/songquanpeng/one-api/relay/model"
relayModel "github.com/songquanpeng/one-api/relay/model"
)
var _ adaptor.Adaptor = new(Adaptor)
@@ -24,14 +26,27 @@ type Adaptor struct{}
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
meta := meta.GetByContext(c)
if request.ResponseFormat != "b64_json" {
return nil, errors.New("only support b64_json response format")
}
adaptor := GetAdaptor(request.Model)
adaptor := GetAdaptor(meta.ActualModelName)
if adaptor == nil {
return nil, errors.New("adaptor not found")
return nil, errors.Errorf("cannot found vertex image adaptor for model %s", meta.ActualModelName)
}
return adaptor.ConvertImageRequest(c, request)
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
meta := meta.GetByContext(c)
adaptor := GetAdaptor(meta.ActualModelName)
if adaptor == nil {
return nil, errors.Errorf("cannot found vertex chat adaptor for model %s", meta.ActualModelName)
}
return adaptor.ConvertRequest(c, relayMode, request)
@@ -40,9 +55,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
adaptor := GetAdaptor(meta.ActualModelName)
if adaptor == nil {
return nil, &relaymodel.ErrorWithStatusCode{
return nil, &relayModel.ErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: relaymodel.Error{
Error: relayModel.Error{
Message: "adaptor not found",
},
}
@@ -60,14 +75,19 @@ func (a *Adaptor) GetChannelName() string {
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
suffix := ""
if strings.HasPrefix(meta.ActualModelName, "gemini") {
var suffix string
switch {
case strings.HasPrefix(meta.ActualModelName, "gemini"):
if meta.IsStream {
suffix = "streamGenerateContent?alt=sse"
} else {
suffix = "generateContent"
}
} else {
case slices.Contains(imagen.ModelList, meta.ActualModelName):
return fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/imagen-3.0-generate-001:predict",
meta.Config.Region, meta.Config.VertexAIProjectID, meta.Config.Region,
), nil
default:
if meta.IsStream {
suffix = "streamRawPredict?alt=sse"
} else {
@@ -85,6 +105,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
suffix,
), nil
}
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
meta.Config.Region,
@@ -105,13 +126,6 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
return nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}
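
For orientation, a minimal sketch (not part of the commit) of the endpoint the new imagen branch resolves to; the region and project ID below are placeholder assumptions:

package main

import "fmt"

// Sketch only: region and project ID are made-up placeholders.
func main() {
	url := fmt.Sprintf(
		"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict",
		"us-central1", "example-project", "us-central1", "imagen-3.0-generate-001",
	)
	fmt.Println(url)
	// https://us-central1-aiplatform.googleapis.com/v1/projects/example-project/locations/us-central1/publishers/google/models/imagen-3.0-generate-001:predict
}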

View File

@@ -50,6 +50,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return req, nil
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
return nil, errors.New("not support image request")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = anthropic.StreamHandler(c, resp)

View File

@@ -35,6 +35,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return geminiRequest, nil
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
return nil, errors.New("not support image request")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string

View File

@@ -0,0 +1,107 @@
package imagen
import (
"encoding/json"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
var ModelList = []string{
"imagen-3.0-generate-001",
}
type Adaptor struct {
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
meta := meta.GetByContext(c)
if request.ResponseFormat != "b64_json" {
return nil, errors.New("only support b64_json response format")
}
if request.N <= 0 {
return nil, errors.New("n must be greater than 0")
}
switch meta.Mode {
case relaymode.ImagesGenerations:
return convertImageCreateRequest(request)
case relaymode.ImagesEdits:
return nil, errors.New("not implemented")
default:
return nil, errors.New("not implemented")
}
}
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
return CreateImageRequest{
Instances: []createImageInstance{
{
Prompt: request.Prompt,
},
},
Parameters: createImageParameters{
SampleCount: request.N,
},
}, nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, openai.ErrorWrapper(
errors.Wrap(err, "failed to read response body"),
"read_response_body",
http.StatusInternalServerError,
)
}
if resp.StatusCode != http.StatusOK {
return nil, openai.ErrorWrapper(
errors.Errorf("upstream response status code: %d, body: %s", resp.StatusCode, string(respBody)),
"upstream_response",
http.StatusInternalServerError,
)
}
imagenResp := new(CreateImageResponse)
if err := json.Unmarshal(respBody, imagenResp); err != nil {
return nil, openai.ErrorWrapper(
errors.Wrap(err, "failed to decode response body"),
"unmarshal_upstream_response",
http.StatusInternalServerError,
)
}
if len(imagenResp.Predictions) == 0 {
return nil, openai.ErrorWrapper(
errors.New("empty predictions"),
"empty_predictions",
http.StatusInternalServerError,
)
}
oaiResp := openai.ImageResponse{
Created: time.Now().Unix(),
}
for _, prediction := range imagenResp.Predictions {
oaiResp.Data = append(oaiResp.Data, openai.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
c.JSON(http.StatusOK, oaiResp)
return nil, nil
}
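
A minimal test-style sketch (an assumption, not part of the commit) of how DoResponse maps an Imagen predict payload onto the OpenAI image response schema:

package imagen_test

import (
	"bytes"
	"io"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gin-gonic/gin"

	"github.com/songquanpeng/one-api/relay/adaptor/vertexai/imagen"
)

// Sketch only: exercises the happy path with a fake upstream body.
func TestDoResponseSketch(t *testing.T) {
	body := `{"predictions":[{"mimeType":"image/png","bytesBase64Encoded":"aGVsbG8="}]}`
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       io.NopCloser(bytes.NewBufferString(body)),
	}

	w := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(w)

	if _, wrapErr := (&imagen.Adaptor{}).DoResponse(c, resp, nil); wrapErr != nil {
		t.Fatalf("unexpected error: %+v", wrapErr)
	}
	// w.Body now contains an OpenAI-style image response carrying the b64_json data.
}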

View File

@@ -0,0 +1,23 @@
package imagen
type CreateImageRequest struct {
Instances []createImageInstance `json:"instances" binding:"required,min=1"`
Parameters createImageParameters `json:"parameters" binding:"required"`
}
type createImageInstance struct {
Prompt string `json:"prompt"`
}
type createImageParameters struct {
SampleCount int `json:"sampleCount" binding:"required,min=1"` // Vertex Imagen predict expects camelCase parameter keys
}
type CreateImageResponse struct {
Predictions []createImageResponsePrediction `json:"predictions"`
}
type createImageResponsePrediction struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
}
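
As an in-package sketch (not part of the commit), marshaling the request built by convertImageCreateRequest shows the wire shape sent to the Imagen predict endpoint; the prompt and count values are made up:

package imagen

import (
	"encoding/json"
	"fmt"

	"github.com/songquanpeng/one-api/relay/model"
)

// exampleWireShape is a sketch only; the prompt and N values are placeholders.
func exampleWireShape() {
	req, err := convertImageCreateRequest(&model.ImageRequest{
		Prompt:         "a red bicycle",
		N:              2,
		ResponseFormat: "b64_json",
	})
	if err != nil {
		panic(err)
	}
	body, _ := json.Marshal(req)
	fmt.Println(string(body))
	// {"instances":[{"prompt":"a red bicycle"}],"parameters":{"sampleCount":2}}
}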

View File

@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
claude "github.com/songquanpeng/one-api/relay/adaptor/vertexai/claude"
gemini "github.com/songquanpeng/one-api/relay/adaptor/vertexai/gemini"
"github.com/songquanpeng/one-api/relay/adaptor/vertexai/imagen"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
)
@@ -13,8 +14,9 @@ import (
type VertexAIModelType int
const (
VerterAIClaude VertexAIModelType = iota + 1
VerterAIGemini
VertexAIClaude VertexAIModelType = iota + 1
VertexAIGemini
VertexAIImagen
)
var modelMapping = map[string]VertexAIModelType{}
@@ -23,27 +25,35 @@ var modelList = []string{}
func init() {
modelList = append(modelList, claude.ModelList...)
for _, model := range claude.ModelList {
modelMapping[model] = VerterAIClaude
modelMapping[model] = VertexAIClaude
}
modelList = append(modelList, gemini.ModelList...)
for _, model := range gemini.ModelList {
modelMapping[model] = VerterAIGemini
modelMapping[model] = VertexAIGemini
}
modelList = append(modelList, imagen.ModelList...)
for _, model := range imagen.ModelList {
modelMapping[model] = VertexAIImagen
}
}
type innerAIAdapter interface {
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error)
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
}
func GetAdaptor(model string) innerAIAdapter {
adaptorType := modelMapping[model]
switch adaptorType {
case VerterAIClaude:
case VertexAIClaude:
return &claude.Adaptor{}
case VerterAIGemini:
case VertexAIGemini:
return &gemini.Adaptor{}
case VertexAIImagen:
return &imagen.Adaptor{}
default:
return nil
}
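
Finally, a short sketch (not from the commit) of how the registry resolves the new imagen sub-adaptor once init() has populated modelMapping:

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
)

// Sketch only: GetAdaptor returns nil for unknown models, which callers
// surface as an "adaptor not found" error.
func main() {
	if sub := vertexai.GetAdaptor("imagen-3.0-generate-001"); sub != nil {
		fmt.Printf("resolved sub-adaptor: %T\n", sub) // *imagen.Adaptor
	}
}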