mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-09-17 09:16:36 +08:00
feat: support vertex imagen3
This commit is contained in:
parent
3915ce9814
commit
009337ccf3
@ -21,4 +21,5 @@ const (
|
||||
AvailableModels = "available_models"
|
||||
KeyRequestBody = "key_request_body"
|
||||
SystemPrompt = "system_prompt"
|
||||
Meta = "meta"
|
||||
)
|
||||
|
@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
return aiProxyLibraryRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
return ConvertRequest(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -72,7 +72,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -109,7 +109,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -19,7 +19,7 @@ type Adaptor struct {
|
||||
}
|
||||
|
||||
// ConvertImageRequest implements adaptor.Adaptor.
|
||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
@ -15,7 +15,7 @@ import (
|
||||
type Adaptor struct{}
|
||||
|
||||
// ConvertImageRequest implements adaptor.Adaptor.
|
||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (*Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
|
@ -38,7 +38,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
return ConvertRequest(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
return convertedRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -65,7 +65,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -13,7 +13,7 @@ type Adaptor interface {
|
||||
GetRequestURL(meta *meta.Meta) (string, error)
|
||||
SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
|
||||
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
|
||||
ConvertImageRequest(request *model.ImageRequest) (any, error)
|
||||
ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error)
|
||||
DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
|
||||
GetModelList() []string
|
||||
|
@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -85,7 +85,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -36,7 +36,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
return ConvertRequest(*request), nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -80,7 +80,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
return nil, errors.Errorf("not implement")
|
||||
}
|
||||
|
||||
|
@ -23,7 +23,29 @@ type Adaptor struct {
|
||||
}
|
||||
|
||||
// ConvertImageRequest implements adaptor.Adaptor.
|
||||
func (*Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
return nil, errors.New("should call replicate.ConvertImageRequest instead")
|
||||
}
|
||||
|
||||
func ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
meta := meta.GetByContext(c)
|
||||
|
||||
if request.ResponseFormat != "b64_json" {
|
||||
return nil, errors.New("only support b64_json response format")
|
||||
}
|
||||
if request.N != 1 && request.N != 0 {
|
||||
return nil, errors.New("only support N=1")
|
||||
}
|
||||
|
||||
switch meta.Mode {
|
||||
case relaymode.ImagesGenerations:
|
||||
return convertImageCreateRequest(request)
|
||||
default:
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
}
|
||||
|
||||
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
|
||||
return DrawImageRequest{
|
||||
Input: ImageInput{
|
||||
Steps: 25,
|
||||
|
@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
return tencentRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -1,18 +1,20 @@
|
||||
package vertexai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||
channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/vertexai/imagen"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
relayModel "github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
var _ adaptor.Adaptor = new(Adaptor)
|
||||
@ -24,14 +26,27 @@ type Adaptor struct{}
|
||||
func (a *Adaptor) Init(meta *meta.Meta) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
meta := meta.GetByContext(c)
|
||||
|
||||
if request.ResponseFormat != "b64_json" {
|
||||
return nil, errors.New("only support b64_json response format")
|
||||
}
|
||||
|
||||
adaptor := GetAdaptor(request.Model)
|
||||
adaptor := GetAdaptor(meta.ActualModelName)
|
||||
if adaptor == nil {
|
||||
return nil, errors.New("adaptor not found")
|
||||
return nil, errors.Errorf("cannot found vertex image adaptor for model %s", meta.ActualModelName)
|
||||
}
|
||||
|
||||
return adaptor.ConvertImageRequest(c, request)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
meta := meta.GetByContext(c)
|
||||
|
||||
adaptor := GetAdaptor(meta.ActualModelName)
|
||||
if adaptor == nil {
|
||||
return nil, errors.Errorf("cannot found vertex chat adaptor for model %s", meta.ActualModelName)
|
||||
}
|
||||
|
||||
return adaptor.ConvertRequest(c, relayMode, request)
|
||||
@ -40,9 +55,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
adaptor := GetAdaptor(meta.ActualModelName)
|
||||
if adaptor == nil {
|
||||
return nil, &relaymodel.ErrorWithStatusCode{
|
||||
return nil, &relayModel.ErrorWithStatusCode{
|
||||
StatusCode: http.StatusInternalServerError,
|
||||
Error: relaymodel.Error{
|
||||
Error: relayModel.Error{
|
||||
Message: "adaptor not found",
|
||||
},
|
||||
}
|
||||
@ -60,14 +75,19 @@ func (a *Adaptor) GetChannelName() string {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
suffix := ""
|
||||
if strings.HasPrefix(meta.ActualModelName, "gemini") {
|
||||
var suffix string
|
||||
switch {
|
||||
case strings.HasPrefix(meta.ActualModelName, "gemini"):
|
||||
if meta.IsStream {
|
||||
suffix = "streamGenerateContent?alt=sse"
|
||||
} else {
|
||||
suffix = "generateContent"
|
||||
}
|
||||
} else {
|
||||
case slices.Contains(imagen.ModelList, meta.ActualModelName):
|
||||
return fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/imagen-3.0-generate-001:predict",
|
||||
meta.Config.Region, meta.Config.VertexAIProjectID, meta.Config.Region,
|
||||
), nil
|
||||
default:
|
||||
if meta.IsStream {
|
||||
suffix = "streamRawPredict?alt=sse"
|
||||
} else {
|
||||
@ -85,6 +105,7 @@ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
meta.Config.Region,
|
||||
@ -105,13 +126,6 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
|
||||
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
|
||||
}
|
||||
|
@ -50,6 +50,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
return nil, errors.New("not support image request")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
err, usage = anthropic.StreamHandler(c, resp)
|
||||
|
@ -35,6 +35,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
return geminiRequest, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
return nil, errors.New("not support image request")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
if meta.IsStream {
|
||||
var responseText string
|
||||
|
105
relay/adaptor/vertexai/imagen/adaptor.go
Normal file
105
relay/adaptor/vertexai/imagen/adaptor.go
Normal file
@ -0,0 +1,105 @@
|
||||
package imagen
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
)
|
||||
|
||||
var ModelList = []string{
|
||||
"imagen-3.0-generate-001",
|
||||
}
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
meta := meta.GetByContext(c)
|
||||
|
||||
if request.ResponseFormat != "b64_json" {
|
||||
return nil, errors.New("only support b64_json response format")
|
||||
}
|
||||
if request.N <= 0 {
|
||||
return nil, errors.New("n must be greater than 0")
|
||||
}
|
||||
|
||||
switch meta.Mode {
|
||||
case relaymode.ImagesGenerations:
|
||||
return convertImageCreateRequest(request)
|
||||
default:
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
}
|
||||
|
||||
func convertImageCreateRequest(request *model.ImageRequest) (any, error) {
|
||||
return CreateImageRequest{
|
||||
Instances: []createImageInstance{
|
||||
{
|
||||
Prompt: request.Prompt,
|
||||
},
|
||||
},
|
||||
Parameters: createImageParameters{
|
||||
SampleCount: request.N,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, wrapErr *model.ErrorWithStatusCode) {
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, openai.ErrorWrapper(
|
||||
errors.Wrap(err, "failed to read response body"),
|
||||
"read_response_body",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, openai.ErrorWrapper(
|
||||
errors.Errorf("upstream response status code: %d, body: %s", resp.StatusCode, string(respBody)),
|
||||
"upstream_response",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
imagenResp := new(CreateImageResponse)
|
||||
if err := json.Unmarshal(respBody, imagenResp); err != nil {
|
||||
return nil, openai.ErrorWrapper(
|
||||
errors.Wrap(err, "failed to decode response body"),
|
||||
"unmarshal_upstream_response",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
if len(imagenResp.Predictions) == 0 {
|
||||
return nil, openai.ErrorWrapper(
|
||||
errors.New("empty predictions"),
|
||||
"empty_predictions",
|
||||
http.StatusInternalServerError,
|
||||
)
|
||||
}
|
||||
|
||||
oaiResp := openai.ImageResponse{
|
||||
Created: time.Now().Unix(),
|
||||
}
|
||||
for _, prediction := range imagenResp.Predictions {
|
||||
oaiResp.Data = append(oaiResp.Data, openai.ImageData{
|
||||
B64Json: prediction.BytesBase64Encoded,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, oaiResp)
|
||||
return nil, nil
|
||||
}
|
23
relay/adaptor/vertexai/imagen/model.go
Normal file
23
relay/adaptor/vertexai/imagen/model.go
Normal file
@ -0,0 +1,23 @@
|
||||
package imagen
|
||||
|
||||
type CreateImageRequest struct {
|
||||
Instances []createImageInstance `json:"instances" binding:"required,min=1"`
|
||||
Parameters createImageParameters `json:"parameters" binding:"required"`
|
||||
}
|
||||
|
||||
type createImageInstance struct {
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type createImageParameters struct {
|
||||
SampleCount int `json:"sample_count" binding:"required,min=1"`
|
||||
}
|
||||
|
||||
type CreateImageResponse struct {
|
||||
Predictions []createImageResponsePrediction `json:"predictions"`
|
||||
}
|
||||
|
||||
type createImageResponsePrediction struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
}
|
@ -6,6 +6,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
claude "github.com/songquanpeng/one-api/relay/adaptor/vertexai/claude"
|
||||
gemini "github.com/songquanpeng/one-api/relay/adaptor/vertexai/gemini"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/vertexai/imagen"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
@ -13,8 +14,9 @@ import (
|
||||
type VertexAIModelType int
|
||||
|
||||
const (
|
||||
VerterAIClaude VertexAIModelType = iota + 1
|
||||
VerterAIGemini
|
||||
VertexAIClaude VertexAIModelType = iota + 1
|
||||
VertexAIGemini
|
||||
VertexAIImagen
|
||||
)
|
||||
|
||||
var modelMapping = map[string]VertexAIModelType{}
|
||||
@ -23,27 +25,35 @@ var modelList = []string{}
|
||||
func init() {
|
||||
modelList = append(modelList, claude.ModelList...)
|
||||
for _, model := range claude.ModelList {
|
||||
modelMapping[model] = VerterAIClaude
|
||||
modelMapping[model] = VertexAIClaude
|
||||
}
|
||||
|
||||
modelList = append(modelList, gemini.ModelList...)
|
||||
for _, model := range gemini.ModelList {
|
||||
modelMapping[model] = VerterAIGemini
|
||||
modelMapping[model] = VertexAIGemini
|
||||
}
|
||||
|
||||
modelList = append(modelList, imagen.ModelList...)
|
||||
for _, model := range imagen.ModelList {
|
||||
modelMapping[model] = VertexAIImagen
|
||||
}
|
||||
}
|
||||
|
||||
type innerAIAdapter interface {
|
||||
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
|
||||
ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error)
|
||||
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
|
||||
}
|
||||
|
||||
func GetAdaptor(model string) innerAIAdapter {
|
||||
adaptorType := modelMapping[model]
|
||||
switch adaptorType {
|
||||
case VerterAIClaude:
|
||||
case VertexAIClaude:
|
||||
return &claude.Adaptor{}
|
||||
case VerterAIGemini:
|
||||
case VertexAIGemini:
|
||||
return &gemini.Adaptor{}
|
||||
case VertexAIImagen:
|
||||
return &imagen.Adaptor{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -80,7 +80,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
|
||||
func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
|
@ -287,6 +287,9 @@ var ModelRatio = map[string]float64{
|
||||
"deepl-ja": 25.0 / 1000 * USD,
|
||||
// https://console.x.ai/
|
||||
"grok-beta": 5.0 / 1000 * USD,
|
||||
// vertex imagen3
|
||||
// https://cloud.google.com/vertex-ai/generative-ai/pricing#imagen-models
|
||||
"imagen-3.0-generate-001": 0.02 * USD,
|
||||
// replicate charges based on the number of generated images
|
||||
// https://replicate.com/pricing
|
||||
"black-forest-labs/flux-1.1-pro": 0.04 * USD,
|
||||
|
@ -18,7 +18,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
metalib "github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 {
|
||||
return 1
|
||||
}
|
||||
|
||||
func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode {
|
||||
func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *metalib.Meta) *relaymodel.ErrorWithStatusCode {
|
||||
// check prompt length
|
||||
if imageRequest.Prompt == "" {
|
||||
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
||||
@ -104,7 +104,7 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
|
||||
|
||||
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||
ctx := c.Request.Context()
|
||||
meta := meta.GetByContext(c)
|
||||
meta := metalib.GetByContext(c)
|
||||
imageRequest, err := getImageRequest(c, meta.Mode)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
|
||||
@ -116,6 +116,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
meta.OriginModelName = imageRequest.Model
|
||||
imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping)
|
||||
meta.ActualModelName = imageRequest.Model
|
||||
metalib.Set2Context(c, meta)
|
||||
|
||||
// model validation
|
||||
bizErr := validateImageRequest(imageRequest, meta)
|
||||
@ -155,8 +156,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
case channeltype.Zhipu,
|
||||
channeltype.Ali,
|
||||
channeltype.Replicate,
|
||||
channeltype.VertextAI,
|
||||
channeltype.Baidu:
|
||||
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
|
||||
finalRequest, err := adaptor.ConvertImageRequest(c, imageRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
@ -34,6 +34,10 @@ type Meta struct {
|
||||
}
|
||||
|
||||
func GetByContext(c *gin.Context) *Meta {
|
||||
if v, ok := c.Get(ctxkey.Meta); ok {
|
||||
return v.(*Meta)
|
||||
}
|
||||
|
||||
meta := Meta{
|
||||
Mode: relaymode.GetByPath(c.Request.URL.Path),
|
||||
ChannelType: c.GetInt(ctxkey.Channel),
|
||||
@ -57,5 +61,11 @@ func GetByContext(c *gin.Context) *Meta {
|
||||
meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType]
|
||||
}
|
||||
meta.APIType = channeltype.ToAPIType(meta.ChannelType)
|
||||
|
||||
Set2Context(c, &meta)
|
||||
return &meta
|
||||
}
|
||||
|
||||
func Set2Context(c *gin.Context, meta *Meta) {
|
||||
c.Set(ctxkey.Meta, meta)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user