From 297b760293bebcab8edc015a7bfbc38402b87867 Mon Sep 17 00:00:00 2001 From: RockYang Date: Mon, 3 Jun 2024 18:34:37 +0800 Subject: [PATCH 1/9] dalle3 and gptt-4o api compatible with azure --- api/handler/chatimpl/chat_handler.go | 2 +- api/service/dalle/service.go | 41 ++++++++++++++++------------ 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 4ad6965a..9c78f30d 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -330,7 +330,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio Content: prompt, }) req.Input["messages"] = reqMgs - } else if session.Model.Platform == types.OpenAI.Value { // extract image for gpt-vision model + } else if session.Model.Platform == types.OpenAI.Value || session.Model.Platform == types.Azure.Value { // extract image for gpt-vision model imgURLs := utils.ExtractImgURL(prompt) logger.Debugf("detected IMG: %+v", imgURLs) var content interface{} diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go index f3e813b2..fa209f20 100644 --- a/api/service/dalle/service.go +++ b/api/service/dalle/service.go @@ -111,11 +111,11 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { // translate prompt if utils.HasChinese(task.Prompt) { content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt)) - if err != nil { - return "", fmt.Errorf("error with translate prompt: %v", err) + if err == nil { + prompt = content + logger.Debugf("重写后提示词:%s", prompt) } - prompt = content - logger.Debugf("重写后提示词:%s", prompt) + } var user model.User @@ -126,8 +126,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { // get image generation API KEY var apiKey model.ApiKey - tx := s.db.Where("platform", types.OpenAI.Value). - Where("type", "img"). + tx := s.db.Where("type", "img"). Where("enabled", true). Order("last_used_at ASC").First(&apiKey) if tx.Error != nil { @@ -139,17 +138,23 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { if len(apiKey.ProxyURL) > 5 { s.httpClient.SetProxyURL(apiKey.ProxyURL).R() } - logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL) - r, err := s.httpClient.R().SetHeader("Content-Type", "application/json"). - SetHeader("Authorization", "Bearer "+apiKey.Value). - SetBody(imgReq{ - Model: "dall-e-3", - Prompt: prompt, - N: 1, - Size: task.Size, - Style: task.Style, - Quality: task.Quality, - }). + reqBody := imgReq{ + Model: "dall-e-3", + Prompt: prompt, + N: 1, + Size: task.Size, + Style: task.Style, + Quality: task.Quality, + } + logger.Infof("Sending %s request, ApiURL:%s, API KEY:%s, BODY: %+v", apiKey.Platform, apiKey.ApiURL, apiKey.Value, reqBody) + request := s.httpClient.R().SetHeader("Content-Type", "application/json") + if apiKey.Platform == types.Azure.Value { + request = request.SetHeader("api-key", apiKey.Value) + } else { + request = request.SetHeader("Authorization", "Bearer "+apiKey.Value) + } + r, err := request.SetHeader("Authorization", "Bearer "+apiKey.Value). + SetBody(reqBody). SetErrorResult(&errRes). SetSuccessResult(&res).Post(apiKey.ApiURL) if err != nil { @@ -157,7 +162,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { } if r.IsErrorState() { - return "", fmt.Errorf("error with send request: %v", errRes.Error) + return "", fmt.Errorf("error with send request, status: %s, %+v", r.Status, errRes.Error) } // update the api key last use time s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix()) From 3c70c8ae5988eeff85f348f1d2ad0fc9653ca7b6 Mon Sep 17 00:00:00 2001 From: RockYang Date: Tue, 4 Jun 2024 16:21:08 +0800 Subject: [PATCH 2/9] fixed bug markmap generation --- api/handler/markmap_handler.go | 58 ++++++++++++---------------------- api/service/dalle/service.go | 5 +-- api/utils/openai.go | 2 +- web/src/router.js | 18 +++++++++-- web/src/views/Home.vue | 26 ++++++++------- web/src/views/Index.vue | 17 ++++++---- 6 files changed, 63 insertions(+), 63 deletions(-) diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go index bf67ab7b..d6565444 100644 --- a/api/handler/markmap_handler.go +++ b/api/handler/markmap_handler.go @@ -183,45 +183,29 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd}) } else { - body, err := io.ReadAll(response.Body) - if err != nil { - return fmt.Errorf("读取响应失败: %v", err) - } - var res types.ApiError - err = json.Unmarshal(body, &res) - if err != nil { - return fmt.Errorf("解析响应失败: %v", err) - } - - // OpenAI API 调用异常处理 - if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") { - // remove key - h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{}) - return errors.New("请求 OpenAI API 失败:API KEY 所关联的账户被禁用。") - } else if strings.Contains(res.Error.Message, "You exceeded your current quota") { - return errors.New("请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。") - } else { - return fmt.Errorf("请求 OpenAI API 失败:%v", res.Error.Message) - } + body, _ := io.ReadAll(response.Body) + return fmt.Errorf("请求 OpenAI API 失败:%s", string(body)) } // 扣减算力 - res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power)) - if res.Error == nil { - // 记录算力消费日志 - var u model.User - h.DB.Where("id", userId).First(&u) - h.DB.Create(&model.PowerLog{ - UserId: u.Id, - Username: u.Username, - Type: types.PowerConsume, - Amount: chatModel.Power, - Mark: types.PowerSub, - Balance: u.Power, - Model: chatModel.Value, - Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value), - CreatedAt: time.Now(), - }) + if chatModel.Power > 0 { + res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power)) + if res.Error == nil { + // 记录算力消费日志 + var u model.User + h.DB.Where("id", userId).First(&u) + h.DB.Create(&model.PowerLog{ + UserId: u.Id, + Username: u.Username, + Type: types.PowerConsume, + Amount: chatModel.Power, + Mark: types.PowerSub, + Balance: u.Power, + Model: chatModel.Value, + Remark: fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value), + CreatedAt: time.Now(), + }) + } } return nil @@ -235,7 +219,7 @@ func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatMod } // use the last unused key if apiKey.Id == 0 { - res = h.DB.Where("platform", types.OpenAI). + res = h.DB.Where("platform", types.OpenAI.Value). Where("type", "chat"). Where("enabled", true).Order("last_used_at ASC").First(apiKey) } diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go index fa209f20..dc66927b 100644 --- a/api/service/dalle/service.go +++ b/api/service/dalle/service.go @@ -153,10 +153,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { } else { request = request.SetHeader("Authorization", "Bearer "+apiKey.Value) } - r, err := request.SetHeader("Authorization", "Bearer "+apiKey.Value). - SetBody(reqBody). - SetErrorResult(&errRes). - SetSuccessResult(&res).Post(apiKey.ApiURL) + r, err := request.SetBody(reqBody).SetErrorResult(&errRes).SetSuccessResult(&res).Post(apiKey.ApiURL) if err != nil { return "", fmt.Errorf("error with send request: %v", err) } diff --git a/api/utils/openai.go b/api/utils/openai.go index 86a976a5..9f012c2f 100644 --- a/api/utils/openai.go +++ b/api/utils/openai.go @@ -54,7 +54,7 @@ type apiErrRes struct { func OpenAIRequest(db *gorm.DB, prompt string) (string, error) { var apiKey model.ApiKey - res := db.Where("platform = ?", types.OpenAI.Value).Where("type", "chat").Where("enabled = ?", true).First(&apiKey) + res := db.Where("platform IN ?", []string{types.OpenAI.Value, types.Azure.Value}).Where("type", "chat").Where("enabled = ?", true).First(&apiKey) if res.Error != nil { return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", res.Error) } diff --git a/web/src/router.js b/web/src/router.js index 64fdecb6..de19cfff 100644 --- a/web/src/router.js +++ b/web/src/router.js @@ -6,19 +6,20 @@ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import {createRouter, createWebHistory} from "vue-router"; +import {ref} from "vue"; +import {httpGet} from "@/utils/http"; const routes = [ { name: 'Index', path: '/', - meta: {title: process.env.VUE_APP_TITLE}, + meta: {title: "首页"}, component: () => import('@/views/Index.vue'), }, { name: 'home', path: '/home', redirect: '/chat', - meta: {title: '首页'}, component: () => import('@/views/Home.vue'), children: [ { @@ -273,11 +274,22 @@ const router = createRouter({ routes: routes, }) +const active = ref(false) +const title = ref('') +httpGet("/api/config/license").then(res => { + active.value = res.data.de_copy +}).catch(() => {}) +httpGet("/api/config/get?key=system").then(res => { + title.value = res.data.title +}).catch(()=>{}) + let prevRoute = null // dynamic change the title when router change router.beforeEach((to, from, next) => { - if (to.meta.title) { + if (!active.value) { document.title = `${to.meta.title} | ${process.env.VUE_APP_TITLE}` + } else { + document.title = `${to.meta.title} | ${title.value}` } prevRoute = from next() diff --git a/web/src/views/Home.vue b/web/src/views/Home.vue index 87043d42..80abdf45 100644 --- a/web/src/views/Home.vue +++ b/web/src/views/Home.vue @@ -46,19 +46,21 @@ {{ loginUser.nickname }} - - - - 用户手册 - - +
+ + + + 用户手册 + + - - - - Geek-AI {{ version }} - - + + + + Geek-AI {{ version }} + + +
diff --git a/web/src/views/Index.vue b/web/src/views/Index.vue index 16f958b5..ae0cdfd6 100644 --- a/web/src/views/Index.vue +++ b/web/src/views/Index.vue @@ -26,8 +26,11 @@ - 登录 - 注册 + + + 登录 + 注册 + @@ -69,6 +72,7 @@ import FooterBar from "@/components/FooterBar.vue"; import {httpGet} from "@/utils/http"; import {ElMessage} from "element-plus"; import {isMobile} from "@/utils/libs"; +import {checkSession} from "@/action/session"; const router = useRouter() @@ -83,6 +87,7 @@ const licenseConfig = ref({}) // const size = Math.max(window.innerWidth * 0.5, window.innerHeight * 0.8) const winHeight = window.innerHeight - 150 const bgClass = ref('fixed-bg') +const isLogin = ref(false) onMounted(() => { httpGet("/api/config/get?key=system").then(res => { @@ -100,11 +105,11 @@ onMounted(() => { }).catch(e => { ElMessage.error("获取 License 配置:" + e.message) }) - init() -}) -const init = () => { -} + checkSession().then(() => { + isLogin.value = true + }).catch(()=>{}) +})