mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-18 01:26:37 +08:00
feat: support vertex imagen3
This commit is contained in:
parent
3915ce9814
commit
009337ccf3
@ -21,4 +21,5 @@ const (
|
|||||||
AvailableModels = "available_models"
|
AvailableModels = "available_models"
|
||||||
KeyRequestBody = "key_request_body"
|
KeyRequestBody = "key_request_body"
|
||||||
SystemPrompt = "system_prompt"
|
SystemPrompt = "system_prompt"
|
||||||
|
Meta = "meta"
|
||||||
)
|
)
|
||||||
|
@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return aiProxyLibraryRequest, nil
|
return aiProxyLibraryRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -67,7 +67,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return ConvertRequest(*request), nil
|
return ConvertRequest(*request), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -72,7 +72,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -109,7 +109,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -19,7 +19,7 @@ type Adaptor struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ConvertImageRequest implements adaptor.Adaptor.
|
// ConvertImageRequest implements adaptor.Adaptor.
|
||||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ import (
|
|||||||
type Adaptor struct{}
|
type Adaptor struct{}
|
||||||
|
|
||||||
// ConvertImageRequest implements adaptor.Adaptor.
|
// ConvertImageRequest implements adaptor.Adaptor.
|
||||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
return nil, errors.New("not implemented")
|
return nil, errors.New("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return ConvertRequest(*request), nil
|
return ConvertRequest(*request), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return convertedRequest, nil
|
return convertedRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -65,7 +65,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -13,7 +13,7 @@ type Adaptor interface {
|
|||||||
GetRequestURL(meta *meta.Meta) (string, error)
|
GetRequestURL(meta *meta.Meta) (string, error)
|
||||||
SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
|
SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
|
||||||
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
|
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
|
||||||
ConvertImageRequest(request *model.ImageRequest) (any, error)
|
ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error)
|
||||||
DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
|
DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
|
||||||
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
|
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
|
||||||
GetModelList() []string
|
GetModelList() []string
|
||||||
|
@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -85,7 +85,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -36,7 +36,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return ConvertRequest(*request), nil
|
return ConvertRequest(*request), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -80,7 +80,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
return nil, errors.Errorf("not implement")
|
return nil, errors.Errorf("not implement")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,7 +23,29 @@ type Adaptor struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ConvertImageRequest implements adaptor.Adaptor.
|
// ConvertImageRequest implements adaptor.Adaptor.
|
||||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
|
return nil, errors.New("should call replicate.ConvertImageRequest instead")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
|
meta := meta.GetByContext(c)
|
||||||
|
|
||||||
|
if request.ResponseFormat != "b64_json" {
|
||||||
|
return nil, errors.New("only support b64_json response format")
|
||||||
|
}
|
||||||
|
if request.N != 1 && request.N != 0 {
|
||||||
|
return nil, errors.New("only support N=1")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch meta.Mode {
|
||||||
|
case relaymode.ImagesGenerations:
|
||||||
|
return convertImageCreateRequest(request)
|
||||||
|
default:
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
|
||||||
return DrawImageRequest{
|
return DrawImageRequest{
|
||||||
Input: ImageInput{
|
Input: ImageInput{
|
||||||
Steps: 25,
|
Steps: 25,
|
||||||
|
@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return tencentRequest, nil
|
return tencentRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -1,18 +1,20 @@
|
|||||||
package vertexai
|
package vertexai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||||
channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
|
channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/vertexai/imagen"
|
||||||
"github.com/songquanpeng/one-api/relay/meta"
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
relayModel "github.com/songquanpeng/one-api/relay/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ adaptor.Adaptor = new(Adaptor)
|
var _ adaptor.Adaptor = new(Adaptor)
|
||||||
@ -24,14 +26,27 @@ type Adaptor struct{}
|
|||||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
meta := meta.GetByContext(c)
|
||||||
return nil, errors.New("request is nil")
|
|
||||||
|
if request.ResponseFormat != "b64_json" {
|
||||||
|
return nil, errors.New("only support b64_json response format")
|
||||||
}
|
}
|
||||||
|
|
||||||
adaptor := GetAdaptor(request.Model)
|
adaptor := GetAdaptor(meta.ActualModelName)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return nil, errors.New("adaptor not found")
|
return nil, errors.Errorf("cannot found vertex image adaptor for model %s", meta.ActualModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return adaptor.ConvertImageRequest(c, request)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||||
|
meta := meta.GetByContext(c)
|
||||||
|
|
||||||
|
adaptor := GetAdaptor(meta.ActualModelName)
|
||||||
|
if adaptor == nil {
|
||||||
|
return nil, errors.Errorf("cannot found vertex chat adaptor for model %s", meta.ActualModelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
return adaptor.ConvertRequest(c, relayMode, request)
|
return adaptor.ConvertRequest(c, relayMode, request)
|
||||||
@ -40,9 +55,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||||
adaptor := GetAdaptor(meta.ActualModelName)
|
adaptor := GetAdaptor(meta.ActualModelName)
|
||||||
if adaptor == nil {
|
if adaptor == nil {
|
||||||
return nil, &relaymodel.ErrorWithStatusCode{
|
return nil, &relayModel.ErrorWithStatusCode{
|
||||||
StatusCode: http.StatusInternalServerError,
|
StatusCode: http.StatusInternalServerError,
|
||||||
Error: relaymodel.Error{
|
Error: relayModel.Error{
|
||||||
Message: "adaptor not found",
|
Message: "adaptor not found",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -60,14 +75,19 @@ func (a *Adaptor) GetChannelName() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||||
suffix := ""
|
var suffix string
|
||||||
if strings.HasPrefix(meta.ActualModelName, "gemini") {
|
switch {
|
||||||
|
case strings.HasPrefix(meta.ActualModelName, "gemini"):
|
||||||
if meta.IsStream {
|
if meta.IsStream {
|
||||||
suffix = "streamGenerateContent?alt=sse"
|
suffix = "streamGenerateContent?alt=sse"
|
||||||
} else {
|
} else {
|
||||||
suffix = "generateContent"
|
suffix = "generateContent"
|
||||||
}
|
}
|
||||||
} else {
|
case slices.Contains(imagen.ModelList, meta.ActualModelName):
|
||||||
|
return fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/imagen-3.0-generate-001:predict",
|
||||||
|
meta.Config.Region, meta.Config.VertexAIProjectID, meta.Config.Region,
|
||||||
|
), nil
|
||||||
|
default:
|
||||||
if meta.IsStream {
|
if meta.IsStream {
|
||||||
suffix = "streamRawPredict?alt=sse"
|
suffix = "streamRawPredict?alt=sse"
|
||||||
} else {
|
} else {
|
||||||
@ -85,6 +105,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
|||||||
suffix,
|
suffix,
|
||||||
), nil
|
), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf(
|
return fmt.Sprintf(
|
||||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||||
meta.Config.Region,
|
meta.Config.Region,
|
||||||
@ -105,13 +126,6 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
|
||||||
if request == nil {
|
|
||||||
return nil, errors.New("request is nil")
|
|
||||||
}
|
|
||||||
return request, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||||
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
|
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
|
||||||
}
|
}
|
||||||
|
@ -50,6 +50,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
|
return nil, errors.New("not support image request")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||||
if meta.IsStream {
|
if meta.IsStream {
|
||||||
err, usage = anthropic.StreamHandler(c, resp)
|
err, usage = anthropic.StreamHandler(c, resp)
|
||||||
|
@ -35,6 +35,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return geminiRequest, nil
|
return geminiRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
|
return nil, errors.New("not support image request")
|
||||||
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||||
if meta.IsStream {
|
if meta.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
|
105
relay/adaptor/vertexai/imagen/adaptor.go
Normal file
105
relay/adaptor/vertexai/imagen/adaptor.go
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
package imagen
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
|
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ModelList = []string{
|
||||||
|
"imagen-3.0-generate-001",
|
||||||
|
}
|
||||||
|
|
||||||
|
type Adaptor struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
|
meta := meta.GetByContext(c)
|
||||||
|
|
||||||
|
if request.ResponseFormat != "b64_json" {
|
||||||
|
return nil, errors.New("only support b64_json response format")
|
||||||
|
}
|
||||||
|
if request.N <= 0 {
|
||||||
|
return nil, errors.New("n must be greater than 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch meta.Mode {
|
||||||
|
case relaymode.ImagesGenerations:
|
||||||
|
return convertImageCreateRequest(request)
|
||||||
|
default:
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
|
||||||
|
return CreateImageRequest{
|
||||||
|
Instances: []createImageInstance{
|
||||||
|
{
|
||||||
|
Prompt: request.Prompt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Parameters: createImageParameters{
|
||||||
|
SampleCount: request.N,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
|
||||||
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, openai.ErrorWrapper(
|
||||||
|
errors.Wrap(err, "failed to read response body"),
|
||||||
|
"read_response_body",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, openai.ErrorWrapper(
|
||||||
|
errors.Errorf("upstream response status code: %d, body: %s", resp.StatusCode, string(respBody)),
|
||||||
|
"upstream_response",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
imagenResp := new(CreateImageResponse)
|
||||||
|
if err := json.Unmarshal(respBody, imagenResp); err != nil {
|
||||||
|
return nil, openai.ErrorWrapper(
|
||||||
|
errors.Wrap(err, "failed to decode response body"),
|
||||||
|
"unmarshal_upstream_response",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(imagenResp.Predictions) == 0 {
|
||||||
|
return nil, openai.ErrorWrapper(
|
||||||
|
errors.New("empty predictions"),
|
||||||
|
"empty_predictions",
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
oaiResp := openai.ImageResponse{
|
||||||
|
Created: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
for _, prediction := range imagenResp.Predictions {
|
||||||
|
oaiResp.Data = append(oaiResp.Data, openai.ImageData{
|
||||||
|
B64Json: prediction.BytesBase64Encoded,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, oaiResp)
|
||||||
|
return nil, nil
|
||||||
|
}
|
23
relay/adaptor/vertexai/imagen/model.go
Normal file
23
relay/adaptor/vertexai/imagen/model.go
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package imagen
|
||||||
|
|
||||||
|
type CreateImageRequest struct {
|
||||||
|
Instances []createImageInstance `json:"instances" binding:"required,min=1"`
|
||||||
|
Parameters createImageParameters `json:"parameters" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type createImageInstance struct {
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type createImageParameters struct {
|
||||||
|
SampleCount int `json:"sample_count" binding:"required,min=1"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CreateImageResponse struct {
|
||||||
|
Predictions []createImageResponsePrediction `json:"predictions"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type createImageResponsePrediction struct {
|
||||||
|
MimeType string `json:"mimeType"`
|
||||||
|
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||||
|
}
|
@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
claude "github.com/songquanpeng/one-api/relay/adaptor/vertexai/claude"
|
claude "github.com/songquanpeng/one-api/relay/adaptor/vertexai/claude"
|
||||||
gemini "github.com/songquanpeng/one-api/relay/adaptor/vertexai/gemini"
|
gemini "github.com/songquanpeng/one-api/relay/adaptor/vertexai/gemini"
|
||||||
|
"github.com/songquanpeng/one-api/relay/adaptor/vertexai/imagen"
|
||||||
"github.com/songquanpeng/one-api/relay/meta"
|
"github.com/songquanpeng/one-api/relay/meta"
|
||||||
"github.com/songquanpeng/one-api/relay/model"
|
"github.com/songquanpeng/one-api/relay/model"
|
||||||
)
|
)
|
||||||
@ -13,8 +14,9 @@ import (
|
|||||||
type VertexAIModelType int
|
type VertexAIModelType int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
VerterAIClaude VertexAIModelType = iota + 1
|
VertexAIClaude VertexAIModelType = iota + 1
|
||||||
VerterAIGemini
|
VertexAIGemini
|
||||||
|
VertexAIImagen
|
||||||
)
|
)
|
||||||
|
|
||||||
var modelMapping = map[string]VertexAIModelType{}
|
var modelMapping = map[string]VertexAIModelType{}
|
||||||
@ -23,27 +25,35 @@ var modelList = []string{}
|
|||||||
func init() {
|
func init() {
|
||||||
modelList = append(modelList, claude.ModelList...)
|
modelList = append(modelList, claude.ModelList...)
|
||||||
for _, model := range claude.ModelList {
|
for _, model := range claude.ModelList {
|
||||||
modelMapping[model] = VerterAIClaude
|
modelMapping[model] = VertexAIClaude
|
||||||
}
|
}
|
||||||
|
|
||||||
modelList = append(modelList, gemini.ModelList...)
|
modelList = append(modelList, gemini.ModelList...)
|
||||||
for _, model := range gemini.ModelList {
|
for _, model := range gemini.ModelList {
|
||||||
modelMapping[model] = VerterAIGemini
|
modelMapping[model] = VertexAIGemini
|
||||||
|
}
|
||||||
|
|
||||||
|
modelList = append(modelList, imagen.ModelList...)
|
||||||
|
for _, model := range imagen.ModelList {
|
||||||
|
modelMapping[model] = VertexAIImagen
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type innerAIAdapter interface {
|
type innerAIAdapter interface {
|
||||||
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
|
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
|
||||||
|
ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error)
|
||||||
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
|
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAdaptor(model string) innerAIAdapter {
|
func GetAdaptor(model string) innerAIAdapter {
|
||||||
adaptorType := modelMapping[model]
|
adaptorType := modelMapping[model]
|
||||||
switch adaptorType {
|
switch adaptorType {
|
||||||
case VerterAIClaude:
|
case VertexAIClaude:
|
||||||
return &claude.Adaptor{}
|
return &claude.Adaptor{}
|
||||||
case VerterAIGemini:
|
case VertexAIGemini:
|
||||||
return &gemini.Adaptor{}
|
return &gemini.Adaptor{}
|
||||||
|
case VertexAIImagen:
|
||||||
|
return &imagen.Adaptor{}
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -80,7 +80,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
@ -287,6 +287,9 @@ var ModelRatio = map[string]float64{
|
|||||||
"deepl-ja": 25.0 / 1000 * USD,
|
"deepl-ja": 25.0 / 1000 * USD,
|
||||||
// https://console.x.ai/
|
// https://console.x.ai/
|
||||||
"grok-beta": 5.0 / 1000 * USD,
|
"grok-beta": 5.0 / 1000 * USD,
|
||||||
|
// vertex imagen3
|
||||||
|
// https://cloud.google.com/vertex-ai/generative-ai/pricing#imagen-models
|
||||||
|
"imagen-3.0-generate-001": 0.02 * USD,
|
||||||
// replicate charges based on the number of generated images
|
// replicate charges based on the number of generated images
|
||||||
// https://replicate.com/pricing
|
// https://replicate.com/pricing
|
||||||
"black-forest-labs/flux-1.1-pro": 0.04 * USD,
|
"black-forest-labs/flux-1.1-pro": 0.04 * USD,
|
||||||
|
@ -18,7 +18,7 @@ import (
|
|||||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||||
"github.com/songquanpeng/one-api/relay/meta"
|
metalib "github.com/songquanpeng/one-api/relay/meta"
|
||||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 {
|
|||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode {
|
func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *metalib.Meta) *relaymodel.ErrorWithStatusCode {
|
||||||
// check prompt length
|
// check prompt length
|
||||||
if imageRequest.Prompt == "" {
|
if imageRequest.Prompt == "" {
|
||||||
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
||||||
@ -104,7 +104,7 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
|
|||||||
|
|
||||||
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
meta := meta.GetByContext(c)
|
meta := metalib.GetByContext(c)
|
||||||
imageRequest, err := getImageRequest(c, meta.Mode)
|
imageRequest, err := getImageRequest(c, meta.Mode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
|
logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
|
||||||
@ -116,6 +116,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
meta.OriginModelName = imageRequest.Model
|
meta.OriginModelName = imageRequest.Model
|
||||||
imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping)
|
imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping)
|
||||||
meta.ActualModelName = imageRequest.Model
|
meta.ActualModelName = imageRequest.Model
|
||||||
|
metalib.Set2Context(c, meta)
|
||||||
|
|
||||||
// model validation
|
// model validation
|
||||||
bizErr := validateImageRequest(imageRequest, meta)
|
bizErr := validateImageRequest(imageRequest, meta)
|
||||||
@ -155,8 +156,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
|||||||
case channeltype.Zhipu,
|
case channeltype.Zhipu,
|
||||||
channeltype.Ali,
|
channeltype.Ali,
|
||||||
channeltype.Replicate,
|
channeltype.Replicate,
|
||||||
|
channeltype.VertextAI,
|
||||||
channeltype.Baidu:
|
channeltype.Baidu:
|
||||||
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
|
finalRequest, err := adaptor.ConvertImageRequest(c, imageRequest)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
|
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
|
@ -34,6 +34,10 @@ type Meta struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetByContext(c *gin.Context) *Meta {
|
func GetByContext(c *gin.Context) *Meta {
|
||||||
|
if v, ok := c.Get(ctxkey.Meta); ok {
|
||||||
|
return v.(*Meta)
|
||||||
|
}
|
||||||
|
|
||||||
meta := Meta{
|
meta := Meta{
|
||||||
Mode: relaymode.GetByPath(c.Request.URL.Path),
|
Mode: relaymode.GetByPath(c.Request.URL.Path),
|
||||||
ChannelType: c.GetInt(ctxkey.Channel),
|
ChannelType: c.GetInt(ctxkey.Channel),
|
||||||
@ -57,5 +61,11 @@ func GetByContext(c *gin.Context) *Meta {
|
|||||||
meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType]
|
meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType]
|
||||||
}
|
}
|
||||||
meta.APIType = channeltype.ToAPIType(meta.ChannelType)
|
meta.APIType = channeltype.ToAPIType(meta.ChannelType)
|
||||||
|
|
||||||
|
Set2Context(c, &meta)
|
||||||
return &meta
|
return &meta
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Set2Context(c *gin.Context, meta *Meta) {
|
||||||
|
c.Set(ctxkey.Meta, meta)
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user