mirror of
https://github.com/yangjian102621/geekai.git
synced 2026-04-26 13:04:30 +08:00
merge v4.2.6
整合 v4.2.6 的后端中间件与服务层重构、前端样式体系迁移和管理端/移动端功能更新,统一清理历史冲突并完成版本升级。 Made-with: Cursor
This commit is contained in:
@@ -8,35 +8,38 @@ package service
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"github.com/imroc/req/v3"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type CaptchaService struct {
|
||||
config types.ApiConfig
|
||||
config types.CaptchaConfig
|
||||
client *req.Client
|
||||
}
|
||||
|
||||
func NewCaptchaService(config types.ApiConfig) *CaptchaService {
|
||||
func NewCaptchaService(captchaConfig types.CaptchaConfig) *CaptchaService {
|
||||
return &CaptchaService{
|
||||
config: config,
|
||||
config: captchaConfig,
|
||||
client: req.C().SetTimeout(10 * time.Second),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CaptchaService) UpdateConfig(config types.CaptchaConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s *CaptchaService) GetConfig() types.CaptchaConfig {
|
||||
return s.config
|
||||
}
|
||||
|
||||
func (s *CaptchaService) Get() (interface{}, error) {
|
||||
if s.config.Token == "" {
|
||||
return nil, errors.New("无效的 API Token")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/captcha/get", s.config.ApiURL)
|
||||
url := fmt.Sprintf("%s/api/captcha/get", types.GeekAPIURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("AppId", s.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||
SetSuccessResult(&res).Get(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return nil, fmt.Errorf("请求 API 失败:%v", err)
|
||||
@@ -49,12 +52,11 @@ func (s *CaptchaService) Get() (interface{}, error) {
|
||||
return res.Data, nil
|
||||
}
|
||||
|
||||
func (s *CaptchaService) Check(data interface{}) bool {
|
||||
url := fmt.Sprintf("%s/api/captcha/check", s.config.ApiURL)
|
||||
func (s *CaptchaService) Check(data any) bool {
|
||||
url := fmt.Sprintf("%s/api/captcha/check", types.GeekAPIURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("AppId", s.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||
SetBodyJsonMarshal(data).
|
||||
SetSuccessResult(&res).Post(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
@@ -68,16 +70,11 @@ func (s *CaptchaService) Check(data interface{}) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *CaptchaService) SlideGet() (interface{}, error) {
|
||||
if s.config.Token == "" {
|
||||
return nil, errors.New("无效的 API Token")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL)
|
||||
func (s *CaptchaService) SlideGet() (any, error) {
|
||||
url := fmt.Sprintf("%s/api/captcha/slide/get", types.GeekAPIURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("AppId", s.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||
SetSuccessResult(&res).Get(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return nil, fmt.Errorf("请求 API 失败:%v", err)
|
||||
@@ -90,12 +87,11 @@ func (s *CaptchaService) SlideGet() (interface{}, error) {
|
||||
return res.Data, nil
|
||||
}
|
||||
|
||||
func (s *CaptchaService) SlideCheck(data interface{}) bool {
|
||||
url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL)
|
||||
func (s *CaptchaService) SlideCheck(data any) bool {
|
||||
url := fmt.Sprintf("%s/api/captcha/slide/check", types.GeekAPIURL)
|
||||
var res types.BizVo
|
||||
r, err := s.client.R().
|
||||
SetHeader("AppId", s.config.AppId).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
|
||||
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
|
||||
SetBodyJsonMarshal(data).
|
||||
SetSuccessResult(&res).Post(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
|
||||
@@ -1,333 +0,0 @@
|
||||
package crawler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/logger"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-rod/rod"
|
||||
"github.com/go-rod/rod/lib/launcher"
|
||||
"github.com/go-rod/rod/lib/proto"
|
||||
)
|
||||
|
||||
// Service 网络爬虫服务
|
||||
type Service struct {
|
||||
browser *rod.Browser
|
||||
}
|
||||
|
||||
// NewService 创建一个新的爬虫服务
|
||||
func NewService() (*Service, error) {
|
||||
// 启动浏览器
|
||||
path, _ := launcher.LookPath()
|
||||
u := launcher.New().Bin(path).
|
||||
Headless(true). // 无头模式
|
||||
Set("disable-web-security", ""). // 禁用网络安全限制
|
||||
Set("disable-gpu", ""). // 禁用 GPU 加速
|
||||
Set("no-sandbox", ""). // 禁用沙箱模式
|
||||
Set("disable-setuid-sandbox", ""). // 禁用 setuid 沙箱
|
||||
MustLaunch()
|
||||
|
||||
browser := rod.New().ControlURL(u).MustConnect()
|
||||
|
||||
return &Service{
|
||||
browser: browser,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SearchResult 搜索结果
|
||||
type SearchResult struct {
|
||||
Title string `json:"title"` // 标题
|
||||
URL string `json:"url"` // 链接
|
||||
Content string `json:"content"` // 内容摘要
|
||||
}
|
||||
|
||||
// WebSearch 网络搜索
|
||||
func (s *Service) WebSearch(keyword string, maxPages int) ([]SearchResult, error) {
|
||||
if keyword == "" {
|
||||
return nil, errors.New("搜索关键词不能为空")
|
||||
}
|
||||
|
||||
if maxPages <= 0 {
|
||||
maxPages = 1
|
||||
}
|
||||
if maxPages > 10 {
|
||||
maxPages = 10 // 最多搜索 10 页
|
||||
}
|
||||
|
||||
results := make([]SearchResult, 0)
|
||||
|
||||
// 使用百度搜索
|
||||
searchURL := fmt.Sprintf("https://www.baidu.com/s?wd=%s", url.QueryEscape(keyword))
|
||||
|
||||
// 设置页面超时
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 创建页面
|
||||
page := s.browser.MustPage()
|
||||
defer page.MustClose()
|
||||
|
||||
// 设置视口大小
|
||||
err := page.SetViewport(&proto.EmulationSetDeviceMetricsOverride{
|
||||
Width: 1280,
|
||||
Height: 800,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("设置视口失败: %v", err)
|
||||
}
|
||||
|
||||
// 导航到搜索页面
|
||||
err = page.Context(ctx).Navigate(searchURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("导航到搜索页面失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待搜索结果加载完成
|
||||
err = page.WaitLoad()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("等待页面加载完成失败: %v", err)
|
||||
}
|
||||
|
||||
// 分析当前页面的搜索结果
|
||||
for i := 0; i < maxPages; i++ {
|
||||
if i > 0 {
|
||||
// 点击下一页按钮
|
||||
nextPage, err := page.Element("a.n")
|
||||
if err != nil || nextPage == nil {
|
||||
break // 没有下一页
|
||||
}
|
||||
|
||||
err = nextPage.Click(proto.InputMouseButtonLeft, 1)
|
||||
if err != nil {
|
||||
break // 点击下一页失败
|
||||
}
|
||||
|
||||
// 等待新页面加载
|
||||
err = page.WaitLoad()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 提取搜索结果
|
||||
resultElements, err := page.Elements(".result, .c-container")
|
||||
if err != nil || resultElements == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, result := range resultElements {
|
||||
// 获取标题
|
||||
titleElement, err := result.Element("h3, .t")
|
||||
if err != nil || titleElement == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
title, err := titleElement.Text()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取 URL
|
||||
linkElement, err := titleElement.Element("a")
|
||||
if err != nil || linkElement == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
href, err := linkElement.Attribute("href")
|
||||
if err != nil || href == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取内容摘要 - 尝试多个可能的选择器
|
||||
var contentElement *rod.Element
|
||||
var content string
|
||||
|
||||
// 尝试多个可能的选择器来适应不同版本的百度搜索结果
|
||||
selectors := []string{".content-right_8Zs40", ".c-abstract", ".content_LJ0WN", ".content"}
|
||||
for _, selector := range selectors {
|
||||
contentElement, err = result.Element(selector)
|
||||
if err == nil && contentElement != nil {
|
||||
content, _ = contentElement.Text()
|
||||
if content != "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果所有选择器都失败,尝试直接从结果块中提取文本
|
||||
if content == "" {
|
||||
// 获取结果元素的所有文本
|
||||
fullText, err := result.Text()
|
||||
if err == nil && fullText != "" {
|
||||
// 简单处理:从全文中移除标题,剩下的可能是摘要
|
||||
fullText = strings.Replace(fullText, title, "", 1)
|
||||
// 清理文本
|
||||
content = strings.TrimSpace(fullText)
|
||||
// 限制内容长度
|
||||
if len(content) > 200 {
|
||||
content = content[:200] + "..."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加到结果集
|
||||
results = append(results, SearchResult{
|
||||
Title: title,
|
||||
URL: *href,
|
||||
Content: content,
|
||||
})
|
||||
|
||||
// 限制结果数量,每页最多 10 条
|
||||
if len(results) >= 10*maxPages {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取真实 URL(百度搜索结果中的 URL 是短链接,需要跳转获取真实 URL)
|
||||
for i, result := range results {
|
||||
realURL, err := s.getRedirectURL(result.URL)
|
||||
if err == nil && realURL != "" {
|
||||
results[i].URL = realURL
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// 获取真实 URL
|
||||
func (s *Service) getRedirectURL(shortURL string) (string, error) {
|
||||
// 创建页面
|
||||
page, err := s.browser.Page(proto.TargetCreateTarget{URL: ""})
|
||||
if err != nil {
|
||||
return shortURL, err // 返回原始URL
|
||||
}
|
||||
defer func() {
|
||||
_ = page.Close()
|
||||
}()
|
||||
|
||||
// 导航到短链接
|
||||
err = page.Navigate(shortURL)
|
||||
if err != nil {
|
||||
return shortURL, err // 返回原始URL
|
||||
}
|
||||
|
||||
// 等待重定向完成
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// 获取当前 URL
|
||||
info, err := page.Info()
|
||||
if err != nil {
|
||||
return shortURL, err // 返回原始URL
|
||||
}
|
||||
|
||||
return info.URL, nil
|
||||
}
|
||||
|
||||
// Close 关闭浏览器
|
||||
func (s *Service) Close() error {
|
||||
if s.browser != nil {
|
||||
err := s.browser.Close()
|
||||
s.browser = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SearchWeb 封装的搜索方法
|
||||
func SearchWeb(keyword string, maxPages int) (string, error) {
|
||||
// 添加panic恢复机制
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log := logger.GetLogger()
|
||||
log.Errorf("爬虫服务崩溃: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
service, err := NewService()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建爬虫服务失败: %v", err)
|
||||
}
|
||||
defer service.Close()
|
||||
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 使用goroutine和通道来处理超时
|
||||
resultChan := make(chan []SearchResult, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
results, err := service.WebSearch(keyword, maxPages)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
resultChan <- results
|
||||
}()
|
||||
|
||||
// 等待结果或超时
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", fmt.Errorf("搜索超时: %v", ctx.Err())
|
||||
case err := <-errChan:
|
||||
return "", fmt.Errorf("搜索失败: %v", err)
|
||||
case results := <-resultChan:
|
||||
if len(results) == 0 {
|
||||
return "未找到关于 \"" + keyword + "\" 的相关搜索结果", nil
|
||||
}
|
||||
|
||||
// 格式化结果
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintf("为您找到关于 \"%s\" 的 %d 条搜索结果:\n\n", keyword, len(results)))
|
||||
|
||||
for i, result := range results {
|
||||
// // 尝试打开链接获取实际内容
|
||||
// page := service.browser.MustPage()
|
||||
// defer page.MustClose()
|
||||
|
||||
// // 设置页面超时
|
||||
// pageCtx, pageCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
// defer pageCancel()
|
||||
|
||||
// // 导航到目标页面
|
||||
// err := page.Context(pageCtx).Navigate(result.URL)
|
||||
// if err == nil {
|
||||
// // 等待页面加载
|
||||
// _ = page.WaitLoad()
|
||||
|
||||
// // 获取页面标题
|
||||
// title, err := page.Eval("() => document.title")
|
||||
// if err == nil && title.Value.String() != "" {
|
||||
// result.Title = title.Value.String()
|
||||
// }
|
||||
|
||||
// // 获取页面主要内容
|
||||
// if content, err := page.Element("body"); err == nil {
|
||||
// if text, err := content.Text(); err == nil {
|
||||
// // 清理并截取内容
|
||||
// text = strings.TrimSpace(text)
|
||||
// if len(text) > 200 {
|
||||
// text = text[:200] + "..."
|
||||
// }
|
||||
// result.Prompt = text
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
builder.WriteString(fmt.Sprintf("%d. **%s**\n", i+1, result.Title))
|
||||
builder.WriteString(fmt.Sprintf(" 链接: %s\n", result.URL))
|
||||
if result.Content != "" {
|
||||
builder.WriteString(fmt.Sprintf(" 摘要: %s\n", result.Content))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
return builder.String(), nil
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
@@ -94,12 +95,14 @@ func (s *Service) Run() {
|
||||
}
|
||||
|
||||
type imgReq struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Image []string `json:"image,omitempty"`
|
||||
Prompt string `json:"prompt"`
|
||||
N int `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
Style string `json:"style,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
}
|
||||
|
||||
type imgRes struct {
|
||||
@@ -122,15 +125,6 @@ type ErrRes struct {
|
||||
|
||||
func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
logger.Debugf("绘画参数:%+v", task)
|
||||
prompt := task.Prompt
|
||||
// translate prompt
|
||||
if utils.HasChinese(prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, prompt), task.TranslateModelId)
|
||||
if err == nil {
|
||||
prompt = content
|
||||
logger.Debugf("重写后提示词:%s", prompt)
|
||||
}
|
||||
}
|
||||
|
||||
var chatModel model.ChatModel
|
||||
if task.ModelId > 0 {
|
||||
@@ -160,12 +154,17 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
|
||||
reqBody := imgReq{
|
||||
Model: chatModel.Value,
|
||||
Prompt: prompt,
|
||||
Prompt: task.Prompt,
|
||||
N: 1,
|
||||
Size: task.Size,
|
||||
Style: task.Style,
|
||||
Quality: task.Quality,
|
||||
}
|
||||
// 图片编辑
|
||||
if len(task.Image) > 0 {
|
||||
reqBody.Prompt = fmt.Sprintf("%s, %s", strings.Join(task.Image, " "), task.Prompt)
|
||||
}
|
||||
|
||||
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
|
||||
r, err := s.httpClient.R().SetHeader("Body-Type", "application/json").
|
||||
SetHeader("Authorization", "Bearer "+apiKey.Value).
|
||||
@@ -188,7 +187,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
var imgURL string
|
||||
var data = map[string]interface{}{
|
||||
"progress": 100,
|
||||
"prompt": prompt,
|
||||
"prompt": task.Prompt,
|
||||
}
|
||||
// 如果返回的是base64,则需要上传到oss
|
||||
if res.Data[0].B64Json != "" {
|
||||
@@ -210,11 +209,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
|
||||
|
||||
var content string
|
||||
if sync {
|
||||
imgURL, err := s.downloadImage(task.Id, res.Data[0].Url)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with download image: %v", err)
|
||||
}
|
||||
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n\n", prompt, imgURL)
|
||||
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n\n", task.Prompt, imgURL)
|
||||
}
|
||||
|
||||
return content, nil
|
||||
|
||||
@@ -3,8 +3,10 @@ package jimeng
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/volcengine/volc-sdk-golang/base"
|
||||
"github.com/volcengine/volc-sdk-golang/service/visual"
|
||||
@@ -13,14 +15,22 @@ import (
|
||||
// Client 即梦API客户端
|
||||
type Client struct {
|
||||
visual *visual.Visual
|
||||
config types.JimengConfig
|
||||
}
|
||||
|
||||
// NewClient 创建即梦API客户端
|
||||
func NewClient(accessKey, secretKey string) *Client {
|
||||
func NewClient(sysConfig *types.SystemConfig) *Client {
|
||||
|
||||
client := &Client{}
|
||||
client.UpdateConfig(sysConfig.Jimeng)
|
||||
return client
|
||||
}
|
||||
|
||||
func (c *Client) UpdateConfig(config types.JimengConfig) error {
|
||||
// 使用官方SDK的visual实例
|
||||
visualInstance := visual.NewInstance()
|
||||
visualInstance.Client.SetAccessKey(accessKey)
|
||||
visualInstance.Client.SetSecretKey(secretKey)
|
||||
visualInstance.Client.SetAccessKey(config.AccessKey)
|
||||
visualInstance.Client.SetSecretKey(config.SecretKey)
|
||||
|
||||
// 添加即梦AI专有的API配置
|
||||
jimengApis := map[string]*base.ApiInfo{
|
||||
@@ -55,9 +65,32 @@ func NewClient(accessKey, secretKey string) *Client {
|
||||
visualInstance.Client.ApiInfoList[name] = info
|
||||
}
|
||||
|
||||
return &Client{
|
||||
visual: visualInstance,
|
||||
c.config = config
|
||||
c.visual = visualInstance
|
||||
|
||||
return c.testConnection()
|
||||
}
|
||||
|
||||
// testConnection 测试即梦AI连接
|
||||
func (c *Client) testConnection() error {
|
||||
|
||||
// 使用一个简单的查询任务来测试连接
|
||||
testReq := &QueryTaskRequest{
|
||||
ReqKey: "test_connection",
|
||||
TaskId: "test_task_id_12345",
|
||||
}
|
||||
|
||||
_, err := c.QueryTask(testReq)
|
||||
// 即使任务不存在,只要不是认证错误就说明连接正常
|
||||
if err != nil {
|
||||
// 检查是否是认证错误
|
||||
if strings.Contains(err.Error(), "InvalidAccessKey") {
|
||||
return fmt.Errorf("认证失败,请检查AccessKey和SecretKey是否正确")
|
||||
}
|
||||
// 其他错误(如任务不存在)说明连接正常
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SubmitTask 提交异步任务
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -16,8 +15,6 @@ import (
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
|
||||
"geekai/core/types"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
@@ -36,17 +33,8 @@ type Service struct {
|
||||
}
|
||||
|
||||
// NewService 创建即梦服务
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager) *Service {
|
||||
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager, client *Client) *Service {
|
||||
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
|
||||
// 从数据库加载配置
|
||||
var config model.Config
|
||||
db.Where("name = ?", "Jimeng").First(&config)
|
||||
var jimengConfig types.JimengConfig
|
||||
if config.Id > 0 {
|
||||
_ = utils.JsonDecode(config.Value, &jimengConfig)
|
||||
}
|
||||
client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Service{
|
||||
db: db,
|
||||
@@ -378,7 +366,7 @@ func (s *Service) pollTaskStatus() {
|
||||
|
||||
for _, job := range jobs {
|
||||
// 任务超时处理
|
||||
if job.UpdatedAt.Before(time.Now().Add(-5 * time.Minute)) {
|
||||
if job.UpdatedAt.Before(time.Now().Add(-10 * time.Minute)) {
|
||||
s.handleTaskError(job.Id, "task timeout")
|
||||
continue
|
||||
}
|
||||
@@ -391,7 +379,7 @@ func (s *Service) pollTaskStatus() {
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logger.Errorf("query jimeng task status failed: %v", err)
|
||||
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", err.Error()))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -446,9 +434,7 @@ func (s *Service) pollTaskStatus() {
|
||||
s.handleTaskError(job.Id, "task not found")
|
||||
|
||||
case model.JMTaskStatusExpired:
|
||||
// 任务过期
|
||||
s.handleTaskError(job.Id, "task expired")
|
||||
|
||||
continue
|
||||
default:
|
||||
logger.Warnf("unknown task status: %s", resp.Data.Status)
|
||||
}
|
||||
@@ -524,77 +510,3 @@ func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
|
||||
}
|
||||
return &job, nil
|
||||
}
|
||||
|
||||
// testConnection 测试即梦AI连接
|
||||
func (s *Service) testConnection(accessKey, secretKey string) error {
|
||||
testClient := NewClient(accessKey, secretKey)
|
||||
|
||||
// 使用一个简单的查询任务来测试连接
|
||||
testReq := &QueryTaskRequest{
|
||||
ReqKey: "test_connection",
|
||||
TaskId: "test_task_id_12345",
|
||||
}
|
||||
|
||||
_, err := testClient.QueryTask(testReq)
|
||||
// 即使任务不存在,只要不是认证错误就说明连接正常
|
||||
if err != nil {
|
||||
// 检查是否是认证错误
|
||||
if strings.Contains(err.Error(), "InvalidAccessKey") {
|
||||
return fmt.Errorf("认证失败,请检查AccessKey和SecretKey是否正确")
|
||||
}
|
||||
// 其他错误(如任务不存在)说明连接正常
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateClientConfig 更新客户端配置
|
||||
func (s *Service) UpdateClientConfig(accessKey, secretKey string) error {
|
||||
// 创建新的客户端
|
||||
newClient := NewClient(accessKey, secretKey)
|
||||
|
||||
// 测试新客户端是否可用
|
||||
err := s.testConnection(accessKey, secretKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新客户端
|
||||
s.client = newClient
|
||||
return nil
|
||||
}
|
||||
|
||||
var defaultPower = types.JimengPower{
|
||||
TextToImage: 20,
|
||||
ImageToImage: 20,
|
||||
ImageEdit: 20,
|
||||
ImageEffects: 20,
|
||||
TextToVideo: 300,
|
||||
ImageToVideo: 300,
|
||||
}
|
||||
|
||||
// GetConfig 获取即梦AI配置
|
||||
func (s *Service) GetConfig() *types.JimengConfig {
|
||||
var config model.Config
|
||||
err := s.db.Where("name", "jimeng").First(&config).Error
|
||||
if err != nil {
|
||||
// 如果配置不存在,返回默认配置
|
||||
return &types.JimengConfig{
|
||||
AccessKey: "",
|
||||
SecretKey: "",
|
||||
Power: defaultPower,
|
||||
}
|
||||
}
|
||||
|
||||
var jimengConfig types.JimengConfig
|
||||
err = utils.JsonDecode(config.Value, &jimengConfig)
|
||||
if err != nil {
|
||||
return &types.JimengConfig{
|
||||
AccessKey: "",
|
||||
SecretKey: "",
|
||||
Power: defaultPower,
|
||||
}
|
||||
}
|
||||
|
||||
return &jimengConfig
|
||||
}
|
||||
|
||||
@@ -8,30 +8,37 @@ package service
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/core/types"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"geekai/utils"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/shirou/gopsutil/host"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type LicenseService struct {
|
||||
config types.ApiConfig
|
||||
levelDB *store.LevelDB
|
||||
license *types.License
|
||||
urlWhiteList []string
|
||||
machineId string
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
|
||||
var license types.License
|
||||
func NewLicenseService(sysConfig *types.SystemConfig, db *gorm.DB) *LicenseService {
|
||||
var machineId string
|
||||
info, err := host.Info()
|
||||
if err == nil {
|
||||
machineId = info.HostID
|
||||
}
|
||||
logger.Infof("License: %+v", sysConfig.License)
|
||||
return &LicenseService{
|
||||
config: server.Config.ApiConfig,
|
||||
levelDB: levelDB,
|
||||
license: &license,
|
||||
machineId: "",
|
||||
license: &sysConfig.License,
|
||||
machineId: machineId,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,15 +53,15 @@ type License struct {
|
||||
}
|
||||
|
||||
// ActiveLicense 激活 License
|
||||
func (s *LicenseService) ActiveLicense(license string, machineId string) error {
|
||||
func (s *LicenseService) ActiveLicense(license string) error {
|
||||
var res struct {
|
||||
Code types.BizCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data License `json:"data"`
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active")
|
||||
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/active")
|
||||
response, err := req.C().R().
|
||||
SetBody(map[string]string{"license": license, "machine_id": machineId}).
|
||||
SetBody(map[string]string{"license": license, "machine_id": s.machineId}).
|
||||
SetSuccessResult(&res).Post(apiURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("发送激活请求失败: %v", err)
|
||||
@@ -68,17 +75,24 @@ func (s *LicenseService) ActiveLicense(license string, machineId string) error {
|
||||
return fmt.Errorf("激活失败:%v", res.Message)
|
||||
}
|
||||
|
||||
if res.Data.ExpiredAt > 0 && res.Data.ExpiredAt < time.Now().Unix() {
|
||||
return fmt.Errorf("License 已过期")
|
||||
}
|
||||
|
||||
s.license = &types.License{
|
||||
Key: license,
|
||||
MachineId: machineId,
|
||||
MachineId: s.machineId,
|
||||
Configs: res.Data.Configs,
|
||||
ExpiredAt: res.Data.ExpiredAt,
|
||||
IsActive: true,
|
||||
}
|
||||
err = s.levelDB.Put(types.LicenseKey, s.license)
|
||||
|
||||
// 保存 License 到数据库
|
||||
err = s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).UpdateColumn("value", utils.JsonEncode(s.license)).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("保存许可证书失败:%v", err)
|
||||
return fmt.Errorf("保存 License 到数据库失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -96,6 +110,11 @@ func (s *LicenseService) SyncLicense() {
|
||||
s.license.IsActive = false
|
||||
} else {
|
||||
s.license = license
|
||||
// 保存 License 到数据库
|
||||
err = s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).UpdateColumn("value", utils.JsonEncode(s.license)).Error
|
||||
if err != nil {
|
||||
logger.Errorf("保存 License 到数据库失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
urls, err := s.fetchUrlWhiteList()
|
||||
@@ -109,33 +128,30 @@ func (s *LicenseService) SyncLicense() {
|
||||
}
|
||||
|
||||
func (s *LicenseService) fetchLicense() (*types.License, error) {
|
||||
//var res struct {
|
||||
// Code types.BizCode `json:"code"`
|
||||
// Message string `json:"message"`
|
||||
// Data License `json:"data"`
|
||||
//}
|
||||
//apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
|
||||
//response, err := req.C().R().
|
||||
// SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
|
||||
// SetSuccessResult(&res).Post(apiURL)
|
||||
//if err != nil {
|
||||
// return nil, fmt.Errorf("发送激活请求失败: %v", err)
|
||||
//}
|
||||
//if response.IsErrorState() {
|
||||
// return nil, fmt.Errorf("激活失败:%v", response.Status)
|
||||
//}
|
||||
//if res.Code != types.Success {
|
||||
// return nil, fmt.Errorf("激活失败:%v", res.Message)
|
||||
//}
|
||||
var res struct {
|
||||
Code types.BizCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data License `json:"data"`
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/check")
|
||||
response, err := req.C().R().
|
||||
SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
|
||||
SetSuccessResult(&res).Post(apiURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("License 同步失败: %v", err)
|
||||
}
|
||||
if response.IsErrorState() {
|
||||
return nil, fmt.Errorf("License 同步失败:%v", response.Status)
|
||||
}
|
||||
if res.Code != types.Success {
|
||||
return nil, fmt.Errorf("License 同步失败:%v", res.Message)
|
||||
}
|
||||
|
||||
return &types.License{
|
||||
Key: "abc",
|
||||
MachineId: "abc",
|
||||
Configs: types.LicenseConfig{
|
||||
UserNum: 10000,
|
||||
DeCopy: false,
|
||||
},
|
||||
ExpiredAt: 0,
|
||||
Key: res.Data.License,
|
||||
MachineId: res.Data.MachineId,
|
||||
Configs: res.Data.Configs,
|
||||
ExpiredAt: res.Data.ExpiredAt,
|
||||
IsActive: true,
|
||||
}, nil
|
||||
}
|
||||
@@ -146,7 +162,7 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
|
||||
Message string `json:"message"`
|
||||
Data []string `json:"data"`
|
||||
}
|
||||
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls")
|
||||
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/urls")
|
||||
response, err := req.C().R().SetSuccessResult(&res).Get(apiURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("发送请求失败: %v", err)
|
||||
@@ -163,35 +179,46 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
|
||||
|
||||
// GetLicense 获取许可信息
|
||||
func (s *LicenseService) GetLicense() *types.License {
|
||||
if s.license == nil {
|
||||
var config model.Config
|
||||
s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).First(&config)
|
||||
if config.Value != "" {
|
||||
utils.JsonDecode(config.Value, &s.license)
|
||||
}
|
||||
}
|
||||
return s.license
|
||||
}
|
||||
|
||||
func (s *LicenseService) SetLicense(licenseKey string) {
|
||||
s.license.Key = licenseKey
|
||||
|
||||
}
|
||||
|
||||
// IsValidApiURL 判断是否合法的中转 URL
|
||||
func (s *LicenseService) IsValidApiURL(uri string) error {
|
||||
// 获得许可授权的直接放行
|
||||
return nil
|
||||
//if s.license.IsActive {
|
||||
// if s.license.MachineId != s.machineId {
|
||||
// return errors.New("系统使用了盗版的许可证书")
|
||||
// }
|
||||
//
|
||||
// if time.Now().Unix() > s.license.ExpiredAt {
|
||||
// return errors.New("系统许可证书已经过期")
|
||||
// }
|
||||
// return nil
|
||||
//}
|
||||
//
|
||||
//if len(s.urlWhiteList) == 0 {
|
||||
// urls, err := s.fetchUrlWhiteList()
|
||||
// if err == nil {
|
||||
// s.urlWhiteList = urls
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//for _, v := range s.urlWhiteList {
|
||||
// if strings.HasPrefix(uri, v) {
|
||||
// return nil
|
||||
// }
|
||||
//}
|
||||
//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
|
||||
if s.license.IsActive {
|
||||
if s.license.MachineId != s.machineId {
|
||||
return errors.New("系统使用了盗版的许可证书")
|
||||
}
|
||||
|
||||
if time.Now().Unix() > s.license.ExpiredAt {
|
||||
return errors.New("系统许可证书已经过期")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(s.urlWhiteList) == 0 {
|
||||
urls, err := s.fetchUrlWhiteList()
|
||||
if err == nil {
|
||||
s.urlWhiteList = urls
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range s.urlWhiteList {
|
||||
if strings.HasPrefix(uri, v) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
|
||||
}
|
||||
|
||||
@@ -1,52 +1,342 @@
|
||||
package service
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// Use of this source code is governed by a Apache-2.0 license
|
||||
// that can be found in the LICENSE file.
|
||||
// @Author yangjian102621@163.com
|
||||
// ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/store"
|
||||
"geekai/store/model"
|
||||
"strings"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
// 迁移状态Redis key
|
||||
MigrationStatusKey = "config_migration:status"
|
||||
// 迁移完成标志
|
||||
MigrationCompleted = "completed"
|
||||
)
|
||||
|
||||
// MigrationService 配置迁移服务
|
||||
type MigrationService struct {
|
||||
db *gorm.DB
|
||||
db *gorm.DB
|
||||
redisClient *redis.Client
|
||||
appConfig *types.AppConfig
|
||||
levelDB *store.LevelDB
|
||||
licenseService *LicenseService
|
||||
}
|
||||
|
||||
func NewMigrationService(db *gorm.DB) *MigrationService {
|
||||
return &MigrationService{db: db}
|
||||
func NewMigrationService(db *gorm.DB, redisClient *redis.Client, appConfig *types.AppConfig, levelDB *store.LevelDB, licenseService *LicenseService) *MigrationService {
|
||||
return &MigrationService{
|
||||
db: db,
|
||||
redisClient: redisClient,
|
||||
appConfig: appConfig,
|
||||
levelDB: levelDB,
|
||||
licenseService: licenseService,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MigrationService) Migrate() error {
|
||||
err := s.db.AutoMigrate(
|
||||
&model.AdminUser{},
|
||||
&model.ApiKey{},
|
||||
&model.AppType{},
|
||||
&model.ChatItem{},
|
||||
&model.ChatMessage{},
|
||||
&model.ChatModel{},
|
||||
&model.ChatRole{},
|
||||
&model.Config{},
|
||||
&model.DallJob{},
|
||||
&model.File{},
|
||||
&model.Function{},
|
||||
&model.InviteCode{},
|
||||
&model.InviteLog{},
|
||||
&model.Menu{},
|
||||
&model.MidJourneyJob{},
|
||||
&model.Order{},
|
||||
&model.PowerLog{},
|
||||
&model.Product{},
|
||||
&model.Redeem{},
|
||||
&model.SdJob{},
|
||||
&model.SunoJob{},
|
||||
&model.User{},
|
||||
&model.UserLoginLog{},
|
||||
&model.VideoJob{},
|
||||
)
|
||||
return err
|
||||
func (s *MigrationService) StartMigrate() {
|
||||
go func() {
|
||||
s.MigrateConfig(s.appConfig)
|
||||
s.TableMigration()
|
||||
s.MigrateLicense()
|
||||
}()
|
||||
}
|
||||
|
||||
// 迁移 License
|
||||
func (s *MigrationService) MigrateLicense() {
|
||||
key := "migrate:license"
|
||||
if s.redisClient.Get(context.Background(), key).Val() == "1" {
|
||||
logger.Info("License 已迁移,跳过迁移")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("开始迁移 License...")
|
||||
var license types.License
|
||||
err := s.levelDB.Get(types.LicenseKey, &license)
|
||||
if err != nil {
|
||||
license = types.License{
|
||||
Key: "",
|
||||
MachineId: "",
|
||||
Configs: types.LicenseConfig{UserNum: 0, DeCopy: false},
|
||||
ExpiredAt: 0,
|
||||
IsActive: false,
|
||||
}
|
||||
}
|
||||
logger.Infof("迁移 License: %+v", license)
|
||||
if err := s.saveConfig(types.ConfigKeyLicense, license); err != nil {
|
||||
logger.Errorf("迁移 License 失败: %v", err)
|
||||
return
|
||||
}
|
||||
s.licenseService.SetLicense(license.Key)
|
||||
logger.Info("迁移 License 完成")
|
||||
s.redisClient.Set(context.Background(), key, "1", 0)
|
||||
}
|
||||
|
||||
// 迁移配置内容
|
||||
func (s *MigrationService) MigrateConfigContent() error {
|
||||
// 用户协议
|
||||
if err := s.saveConfig(types.ConfigKeyPrivacy, map[string]string{
|
||||
"content": "用户协议内容",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
// 隐私政策
|
||||
if err := s.saveConfig(types.ConfigKeyAgreement, map[string]string{
|
||||
"content": "隐私政策内容",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
// 思维导图
|
||||
if err := s.saveConfig(types.ConfigKeyMarkMap, map[string]string{
|
||||
"content": `# GeekAI 演示站
|
||||
|
||||
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
|
||||
- 基于 Websocket 实现,完美的打字机体验。
|
||||
- 内置了各种预训练好的角色应用,轻松满足你的各种聊天和应用需求。
|
||||
- 支持 OPenAI,Azure,文心一言,讯飞星火,清华 ChatGLM等多个大语言模型。
|
||||
- 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用。
|
||||
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
|
||||
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
|
||||
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件。`,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
|
||||
// 微信登录配置
|
||||
if err := s.saveConfig(types.ConfigKeyWxLogin, map[string]string{
|
||||
"api_key": "",
|
||||
"notify_url": "",
|
||||
"enabled": "false",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证码配置
|
||||
if err := s.saveConfig(types.ConfigKeyCaptcha, map[string]string{
|
||||
"api_key": "",
|
||||
"type": "dot",
|
||||
"enabled": "false",
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
|
||||
// 文本审核
|
||||
if err := s.saveConfig(types.ConfigKeyModeration, map[string]any{
|
||||
"enable": "false",
|
||||
"active": "gitee",
|
||||
"enable_guide": "false",
|
||||
"guide_prompt": "",
|
||||
"gitee": map[string]string{
|
||||
"api_key": "",
|
||||
"model": "Security-semantic-filtering",
|
||||
},
|
||||
"baidu": map[string]string{
|
||||
"access_key": "",
|
||||
"secret_key": "",
|
||||
},
|
||||
"tencent": map[string]string{
|
||||
"access_key": "",
|
||||
"secret_key": "",
|
||||
},
|
||||
}); err != nil {
|
||||
return fmt.Errorf("迁移配置内容失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 数据表迁移
|
||||
func (s *MigrationService) TableMigration() {
|
||||
// 新数据表
|
||||
s.db.AutoMigrate(&model.Moderation{})
|
||||
|
||||
// 订单字段整理
|
||||
if s.db.Migrator().HasColumn(&model.Order{}, "pay_type") {
|
||||
s.db.Migrator().RenameColumn(&model.Order{}, "pay_type", "channel")
|
||||
}
|
||||
if !s.db.Migrator().HasColumn(&model.Order{}, "checked") {
|
||||
s.db.Migrator().AddColumn(&model.Order{}, "checked")
|
||||
}
|
||||
|
||||
// 重命名 config 表字段
|
||||
if s.db.Migrator().HasColumn(&model.Config{}, "config_json") {
|
||||
s.db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Config{}, "marker") {
|
||||
s.db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
|
||||
}
|
||||
if s.db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
|
||||
s.db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
|
||||
}
|
||||
if s.db.Migrator().HasIndex(&model.Config{}, "marker") {
|
||||
s.db.Migrator().DropIndex(&model.Config{}, "marker")
|
||||
}
|
||||
|
||||
// 手动删除字段
|
||||
if s.db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
|
||||
s.db.Migrator().DropColumn(&model.Order{}, "deleted_at")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") {
|
||||
s.db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") {
|
||||
s.db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.User{}, "chat_config") {
|
||||
s.db.Migrator().DropColumn(&model.User{}, "chat_config")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.ChatModel{}, "category") {
|
||||
s.db.Migrator().DropColumn(&model.ChatModel{}, "category")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.ChatModel{}, "description") {
|
||||
s.db.Migrator().DropColumn(&model.ChatModel{}, "description")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Product{}, "discount") {
|
||||
s.db.Migrator().DropColumn(&model.Product{}, "discount")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Product{}, "days") {
|
||||
s.db.Migrator().DropColumn(&model.Product{}, "days")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Product{}, "app_url") {
|
||||
s.db.Migrator().DropColumn(&model.Product{}, "app_url")
|
||||
}
|
||||
if s.db.Migrator().HasColumn(&model.Product{}, "url") {
|
||||
s.db.Migrator().DropColumn(&model.Product{}, "url")
|
||||
}
|
||||
}
|
||||
|
||||
// 迁移配置数据
|
||||
func (s *MigrationService) MigrateConfig(config *types.AppConfig) error {
|
||||
|
||||
logger.Info("开始迁移配置到数据库...")
|
||||
|
||||
// 迁移支付配置
|
||||
if err := s.migratePaymentConfig(config); err != nil {
|
||||
logger.Errorf("迁移支付配置失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 迁移存储配置
|
||||
if err := s.migrateStorageConfig(config); err != nil {
|
||||
logger.Errorf("迁移存储配置失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 迁移通信配置
|
||||
if err := s.migrateCommunicationConfig(config); err != nil {
|
||||
logger.Errorf("迁移通信配置失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 迁移配置内容
|
||||
if err := s.MigrateConfigContent(); err != nil {
|
||||
logger.Errorf("迁移配置内容失败: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Info("配置迁移完成")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 迁移支付配置
|
||||
func (s *MigrationService) migratePaymentConfig(config *types.AppConfig) error {
|
||||
|
||||
paymentConfig := types.PaymentConfig{
|
||||
Alipay: config.AlipayConfig,
|
||||
Epay: config.GeekPayConfig,
|
||||
WxPay: config.WechatPayConfig,
|
||||
}
|
||||
if err := s.saveConfig(types.ConfigKeyPayment, paymentConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 迁移存储配置
|
||||
func (s *MigrationService) migrateStorageConfig(config *types.AppConfig) error {
|
||||
|
||||
ossConfig := types.OSSConfig{
|
||||
Active: config.OSS.Active,
|
||||
Local: config.OSS.Local,
|
||||
Minio: config.OSS.Minio,
|
||||
QiNiu: config.OSS.QiNiu,
|
||||
AliYun: config.OSS.AliYun,
|
||||
}
|
||||
return s.saveConfig(types.ConfigKeyOss, ossConfig)
|
||||
}
|
||||
|
||||
// 迁移通信配置
|
||||
func (s *MigrationService) migrateCommunicationConfig(config *types.AppConfig) error {
|
||||
// SMTP配置
|
||||
smtpConfig := map[string]any{
|
||||
"use_tls": config.SmtpConfig.UseTls,
|
||||
"host": config.SmtpConfig.Host,
|
||||
"port": config.SmtpConfig.Port,
|
||||
"app_name": config.SmtpConfig.AppName,
|
||||
"from": config.SmtpConfig.From,
|
||||
"password": config.SmtpConfig.Password,
|
||||
}
|
||||
if err := s.saveConfig(types.ConfigKeySmtp, smtpConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 短信配置
|
||||
smsConfig := map[string]any{
|
||||
"active": strings.ToLower(config.SMS.Active),
|
||||
"aliyun": map[string]any{
|
||||
"access_key": config.SMS.Ali.AccessKey,
|
||||
"access_secret": config.SMS.Ali.AccessSecret,
|
||||
"sign": config.SMS.Ali.Sign,
|
||||
"code_temp_id": config.SMS.Ali.CodeTempId,
|
||||
},
|
||||
"bao": map[string]any{
|
||||
"username": config.SMS.Bao.Username,
|
||||
"password": config.SMS.Bao.Password,
|
||||
"sign": config.SMS.Bao.Sign,
|
||||
"code_template": config.SMS.Bao.CodeTemplate,
|
||||
},
|
||||
}
|
||||
return s.saveConfig(types.ConfigKeySms, smsConfig)
|
||||
}
|
||||
|
||||
// 保存配置到数据库
|
||||
func (s *MigrationService) saveConfig(key string, config any) error {
|
||||
// 检查是否已存在
|
||||
var existingConfig model.Config
|
||||
if err := s.db.Where("name", key).First(&existingConfig).Error; err == nil {
|
||||
// 配置已存在,跳过
|
||||
logger.Infof("配置 %s 已存在,跳过迁移", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 序列化配置
|
||||
configJSON, err := json.Marshal(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
newConfig := model.Config{
|
||||
Name: key,
|
||||
Value: string(configJSON),
|
||||
}
|
||||
if err := s.db.Create(&newConfig).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logger.Infof("成功迁移配置 %s", key)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -67,25 +67,6 @@ func (s *Service) Run() {
|
||||
continue
|
||||
}
|
||||
|
||||
// translate prompt
|
||||
if utils.HasChinese(task.Prompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
|
||||
if err == nil {
|
||||
task.Prompt = content
|
||||
} else {
|
||||
logger.Warnf("error with translate prompt: %v", err)
|
||||
}
|
||||
}
|
||||
// translate negative prompt
|
||||
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
|
||||
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), task.TranslateModelId)
|
||||
if err == nil {
|
||||
task.NegPrompt = content
|
||||
} else {
|
||||
logger.Warnf("error with translate prompt: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// use fast mode as default
|
||||
if task.Mode == "" {
|
||||
task.Mode = "fast"
|
||||
|
||||
33
api/service/moderation/baidu_moderation.go
Normal file
33
api/service/moderation/baidu_moderation.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package moderation
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"geekai/core/types"
|
||||
)
|
||||
|
||||
type BaiduAIModeration struct {
|
||||
config types.ModerationBaiduConfig
|
||||
}
|
||||
|
||||
func NewBaiduAIModeration(sysConfig *types.SystemConfig) *BaiduAIModeration {
|
||||
return &BaiduAIModeration{
|
||||
config: sysConfig.Moderation.Baidu,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BaiduAIModeration) UpdateConfig(config types.ModerationBaiduConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s *BaiduAIModeration) Moderate(text string) (types.ModerationResult, error) {
|
||||
return types.ModerationResult{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
var _ Service = (*BaiduAIModeration)(nil)
|
||||
58
api/service/moderation/gitee_moderation.go
Normal file
58
api/service/moderation/gitee_moderation.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package moderation
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"geekai/core/types"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type GiteeAIModeration struct {
|
||||
config types.ModerationGiteeConfig
|
||||
apiURL string
|
||||
}
|
||||
|
||||
func NewGiteeAIModeration(sysConfig *types.SystemConfig) *GiteeAIModeration {
|
||||
return &GiteeAIModeration{
|
||||
config: sysConfig.Moderation.Gitee,
|
||||
apiURL: "https://ai.gitee.com/v1/moderations",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GiteeAIModeration) UpdateConfig(config types.ModerationGiteeConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
type GiteeAIModerationResult struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Results []types.ModerationResult `json:"results"`
|
||||
}
|
||||
|
||||
func (s *GiteeAIModeration) Moderate(text string) (types.ModerationResult, error) {
|
||||
|
||||
body := map[string]any{
|
||||
"input": text,
|
||||
"model": s.config.Model,
|
||||
}
|
||||
var res GiteeAIModerationResult
|
||||
r, err := req.C().R().SetHeader("Authorization", "Bearer "+s.config.ApiKey).SetBody(body).SetSuccessResult(&res).Post(s.apiURL)
|
||||
if err != nil {
|
||||
return types.ModerationResult{}, err
|
||||
}
|
||||
|
||||
if r.IsErrorState() {
|
||||
return types.ModerationResult{}, errors.New(r.String())
|
||||
}
|
||||
|
||||
return res.Results[0], nil
|
||||
}
|
||||
|
||||
var _ Service = (*GiteeAIModeration)(nil)
|
||||
58
api/service/moderation/moderation_manager.go
Normal file
58
api/service/moderation/moderation_manager.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package moderation
|
||||
|
||||
import (
|
||||
"geekai/core/types"
|
||||
|
||||
logger2 "geekai/logger"
|
||||
)
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type Service interface {
|
||||
Moderate(text string) (types.ModerationResult, error)
|
||||
}
|
||||
|
||||
type ServiceManager struct {
|
||||
gitee *GiteeAIModeration
|
||||
baidu *BaiduAIModeration
|
||||
tencent *TencentAIModeration
|
||||
active string
|
||||
}
|
||||
|
||||
func NewServiceManager(gitee *GiteeAIModeration, baidu *BaiduAIModeration, tencent *TencentAIModeration) *ServiceManager {
|
||||
return &ServiceManager{
|
||||
gitee: gitee,
|
||||
baidu: baidu,
|
||||
tencent: tencent,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServiceManager) GetService() Service {
|
||||
switch s.active {
|
||||
case types.ModerationBaidu:
|
||||
return s.baidu
|
||||
case types.ModerationTencent:
|
||||
return s.tencent
|
||||
default:
|
||||
return s.gitee
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServiceManager) UpdateConfig(config types.ModerationConfig) {
|
||||
switch config.Active {
|
||||
case types.ModerationGitee:
|
||||
s.gitee.UpdateConfig(config.Gitee)
|
||||
case types.ModerationBaidu:
|
||||
s.baidu.UpdateConfig(config.Baidu)
|
||||
case types.ModerationTencent:
|
||||
s.tencent.UpdateConfig(config.Tencent)
|
||||
}
|
||||
s.active = config.Active
|
||||
}
|
||||
33
api/service/moderation/tencent_moderation.go
Normal file
33
api/service/moderation/tencent_moderation.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package moderation
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"geekai/core/types"
|
||||
)
|
||||
|
||||
type TencentAIModeration struct {
|
||||
config types.ModerationTencentConfig
|
||||
}
|
||||
|
||||
func NewTencentAIModeration(sysConfig *types.SystemConfig) *TencentAIModeration {
|
||||
return &TencentAIModeration{
|
||||
config: sysConfig.Moderation.Tencent,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *TencentAIModeration) UpdateConfig(config types.ModerationTencentConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s *TencentAIModeration) Moderate(text string) (types.ModerationResult, error) {
|
||||
return types.ModerationResult{}, errors.New("not implemented")
|
||||
}
|
||||
|
||||
var _ Service = (*TencentAIModeration)(nil)
|
||||
@@ -23,35 +23,35 @@ import (
|
||||
)
|
||||
|
||||
type AliYunOss struct {
|
||||
config *types.AliYunOssConfig
|
||||
config types.AliYunOssConfig
|
||||
bucket *oss.Bucket
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) {
|
||||
config := &appConfig.OSS.AliYun
|
||||
// 创建 OSS 客户端
|
||||
func NewAliYunOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*AliYunOss, error) {
|
||||
s := &AliYunOss{
|
||||
proxyURL: appConfig.ProxyURL,
|
||||
}
|
||||
err := s.UpdateConfig(sysConfig.OSS.AliYun)
|
||||
if err != nil {
|
||||
logger.Warnf("阿里云OSS初始化失败: %v", err)
|
||||
}
|
||||
return s, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *AliYunOss) UpdateConfig(config types.AliYunOssConfig) error {
|
||||
client, err := oss.New(config.Endpoint, config.AccessKey, config.AccessSecret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取存储空间
|
||||
bucket, err := client.Bucket(config.Bucket)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if config.SubDir == "" {
|
||||
config.SubDir = "gpt"
|
||||
}
|
||||
|
||||
return &AliYunOss{
|
||||
config: config,
|
||||
bucket: bucket,
|
||||
proxyURL: appConfig.ProxyURL,
|
||||
}, nil
|
||||
|
||||
s.bucket = bucket
|
||||
s.config = config
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
@@ -68,7 +68,7 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
defer src.Close()
|
||||
|
||||
fileExt := filepath.Ext(file.Filename)
|
||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
|
||||
// 上传文件
|
||||
err = s.bucket.PutObject(objectKey, src)
|
||||
if err != nil {
|
||||
@@ -102,7 +102,7 @@ func (s AliYunOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
|
||||
// 上传文件字节数据
|
||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
|
||||
if err != nil {
|
||||
@@ -116,7 +116,7 @@ func (s AliYunOss) PutBase64(base64Img string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
|
||||
// 上传文件字节数据
|
||||
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
|
||||
if err != nil {
|
||||
@@ -128,8 +128,7 @@ func (s AliYunOss) PutBase64(base64Img string) (string, error) {
|
||||
func (s AliYunOss) Delete(fileURL string) error {
|
||||
var objectKey string
|
||||
if strings.HasPrefix(fileURL, "http") {
|
||||
filename := filepath.Base(fileURL)
|
||||
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
|
||||
objectKey = filepath.Base(fileURL)
|
||||
} else {
|
||||
objectKey = fileURL
|
||||
}
|
||||
|
||||
@@ -21,17 +21,21 @@ import (
|
||||
)
|
||||
|
||||
type LocalStorage struct {
|
||||
config *types.LocalStorageConfig
|
||||
config types.LocalStorageConfig
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
func NewLocalStorage(config *types.AppConfig) LocalStorage {
|
||||
return LocalStorage{
|
||||
config: &config.OSS.Local,
|
||||
proxyURL: config.ProxyURL,
|
||||
func NewLocalStorage(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *LocalStorage {
|
||||
return &LocalStorage{
|
||||
config: sysConfig.OSS.Local,
|
||||
proxyURL: appConfig.ProxyURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LocalStorage) UpdateConfig(config types.LocalStorageConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
file, err := ctx.FormFile(name)
|
||||
if err != nil {
|
||||
|
||||
@@ -24,24 +24,32 @@ import (
|
||||
)
|
||||
|
||||
type MiniOss struct {
|
||||
config *types.MiniOssConfig
|
||||
config types.MiniOssConfig
|
||||
client *minio.Client
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
|
||||
config := &appConfig.OSS.Minio
|
||||
func NewMiniOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*MiniOss, error) {
|
||||
|
||||
s := &MiniOss{proxyURL: appConfig.ProxyURL}
|
||||
err := s.UpdateConfig(sysConfig.OSS.Minio)
|
||||
if err != nil {
|
||||
logger.Warnf("MinioOSS初始化失败: %v", err)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *MiniOss) UpdateConfig(config types.MiniOssConfig) error {
|
||||
minioClient, err := minio.New(config.Endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(config.AccessKey, config.AccessSecret, ""),
|
||||
Secure: config.UseSSL,
|
||||
})
|
||||
if err != nil {
|
||||
return MiniOss{}, err
|
||||
return err
|
||||
}
|
||||
if config.SubDir == "" {
|
||||
config.SubDir = "gpt"
|
||||
}
|
||||
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
|
||||
s.config = config
|
||||
s.client = minioClient
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
@@ -62,7 +70,7 @@ func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string,
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
|
||||
info, err := s.client.PutObject(
|
||||
context.Background(),
|
||||
s.config.Bucket,
|
||||
@@ -89,7 +97,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
defer fileReader.Close()
|
||||
|
||||
fileExt := filepath.Ext(file.Filename)
|
||||
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
|
||||
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
|
||||
ContentType: file.Header.Get("Body-Type"),
|
||||
})
|
||||
@@ -111,7 +119,7 @@ func (s MiniOss) PutBase64(base64Img string) (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
|
||||
info, err := s.client.PutObject(
|
||||
context.Background(),
|
||||
s.config.Bucket,
|
||||
@@ -128,8 +136,7 @@ func (s MiniOss) PutBase64(base64Img string) (string, error) {
|
||||
func (s MiniOss) Delete(fileURL string) error {
|
||||
var objectKey string
|
||||
if strings.HasPrefix(fileURL, "http") {
|
||||
filename := filepath.Base(fileURL)
|
||||
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
|
||||
objectKey = filepath.Base(fileURL)
|
||||
} else {
|
||||
objectKey = fileURL
|
||||
}
|
||||
|
||||
@@ -24,18 +24,24 @@ import (
|
||||
"github.com/qiniu/go-sdk/v7/storage"
|
||||
)
|
||||
|
||||
type QinNiuOss struct {
|
||||
config *types.QiNiuOssConfig
|
||||
type QiNiuOss struct {
|
||||
config types.QiNiuOssConfig
|
||||
mac *qbox.Mac
|
||||
putPolicy storage.PutPolicy
|
||||
uploader *storage.FormUploader
|
||||
manager *storage.BucketManager
|
||||
bucket *storage.BucketManager
|
||||
proxyURL string
|
||||
}
|
||||
|
||||
func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
|
||||
config := &appConfig.OSS.QiNiu
|
||||
// build storage uploader
|
||||
func NewQiNiuOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *QiNiuOss {
|
||||
s := &QiNiuOss{
|
||||
proxyURL: appConfig.ProxyURL,
|
||||
}
|
||||
s.UpdateConfig(sysConfig.OSS.QiNiu)
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *QiNiuOss) UpdateConfig(config types.QiNiuOssConfig) {
|
||||
zone, ok := storage.GetRegionByID(storage.RegionID(config.Zone))
|
||||
if !ok {
|
||||
zone = storage.ZoneHuanan
|
||||
@@ -47,20 +53,13 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
|
||||
putPolicy := storage.PutPolicy{
|
||||
Scope: config.Bucket,
|
||||
}
|
||||
if config.SubDir == "" {
|
||||
config.SubDir = "gpt"
|
||||
}
|
||||
return QinNiuOss{
|
||||
config: config,
|
||||
mac: mac,
|
||||
putPolicy: putPolicy,
|
||||
uploader: formUploader,
|
||||
manager: storage.NewBucketManager(mac, &storeConfig),
|
||||
proxyURL: appConfig.ProxyURL,
|
||||
}
|
||||
s.config = config
|
||||
s.mac = mac
|
||||
s.putPolicy = putPolicy
|
||||
s.uploader = formUploader
|
||||
s.bucket = storage.NewBucketManager(mac, &storeConfig)
|
||||
}
|
||||
|
||||
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
func (s QiNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
// 解析表单
|
||||
file, err := ctx.FormFile(name)
|
||||
if err != nil {
|
||||
@@ -74,7 +73,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
defer src.Close()
|
||||
|
||||
fileExt := filepath.Ext(file.Filename)
|
||||
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
|
||||
key := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
|
||||
// 上传文件
|
||||
ret := storage.PutRet{}
|
||||
extra := storage.PutExtra{}
|
||||
@@ -93,7 +92,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
|
||||
|
||||
}
|
||||
|
||||
func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
func (s QiNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
|
||||
var fileData []byte
|
||||
var err error
|
||||
if useProxy {
|
||||
@@ -111,7 +110,7 @@ func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
|
||||
if ext == "" {
|
||||
ext = filepath.Ext(parse.Path)
|
||||
}
|
||||
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
|
||||
key := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
|
||||
ret := storage.PutRet{}
|
||||
extra := storage.PutExtra{}
|
||||
// 上传文件字节数据
|
||||
@@ -122,12 +121,12 @@ func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
|
||||
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||
}
|
||||
|
||||
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
|
||||
func (s QiNiuOss) PutBase64(base64Img string) (string, error) {
|
||||
imageData, err := base64.StdEncoding.DecodeString(base64Img)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error decoding base64:%v", err)
|
||||
}
|
||||
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
|
||||
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
|
||||
ret := storage.PutRet{}
|
||||
extra := storage.PutExtra{}
|
||||
// 上传文件字节数据
|
||||
@@ -138,16 +137,15 @@ func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
|
||||
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
|
||||
}
|
||||
|
||||
func (s QinNiuOss) Delete(fileURL string) error {
|
||||
func (s QiNiuOss) Delete(fileURL string) error {
|
||||
var objectKey string
|
||||
if strings.HasPrefix(fileURL, "http") {
|
||||
filename := filepath.Base(fileURL)
|
||||
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
|
||||
objectKey = filepath.Base(fileURL)
|
||||
} else {
|
||||
objectKey = fileURL
|
||||
}
|
||||
|
||||
return s.manager.Delete(s.config.Bucket, objectKey)
|
||||
return s.bucket.Delete(s.config.Bucket, objectKey)
|
||||
}
|
||||
|
||||
var _ Uploader = QinNiuOss{}
|
||||
var _ Uploader = QiNiuOss{}
|
||||
|
||||
@@ -9,10 +9,10 @@ package oss
|
||||
|
||||
import "github.com/gin-gonic/gin"
|
||||
|
||||
const Local = "LOCAL"
|
||||
const Minio = "MINIO"
|
||||
const QiNiu = "QINIU"
|
||||
const AliYun = "ALIYUN"
|
||||
const Local = "local"
|
||||
const Minio = "minio"
|
||||
const QiNiu = "qiniu"
|
||||
const AliYun = "aliyun"
|
||||
|
||||
type File struct {
|
||||
Name string `json:"name"`
|
||||
|
||||
@@ -9,45 +9,58 @@ package oss
|
||||
|
||||
import (
|
||||
"geekai/core/types"
|
||||
"strings"
|
||||
|
||||
logger2 "geekai/logger"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type UploaderManager struct {
|
||||
handler Uploader
|
||||
local *LocalStorage
|
||||
aliyun *AliYunOss
|
||||
mini *MiniOss
|
||||
qiniu *QiNiuOss
|
||||
active string
|
||||
}
|
||||
|
||||
func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) {
|
||||
active := Local
|
||||
if config.OSS.Active != "" {
|
||||
active = strings.ToUpper(config.OSS.Active)
|
||||
}
|
||||
var handler Uploader
|
||||
switch active {
|
||||
case Local:
|
||||
handler = NewLocalStorage(config)
|
||||
break
|
||||
case Minio:
|
||||
client, err := NewMiniOss(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handler = client
|
||||
break
|
||||
case QiNiu:
|
||||
handler = NewQiNiuOss(config)
|
||||
break
|
||||
case AliYun:
|
||||
client, err := NewAliYunOss(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handler = client
|
||||
break
|
||||
func NewUploaderManager(sysConfig *types.SystemConfig, local *LocalStorage, aliyun *AliYunOss, mini *MiniOss, qiniu *QiNiuOss) (*UploaderManager, error) {
|
||||
if sysConfig.OSS.Active == "" {
|
||||
sysConfig.OSS.Active = Local
|
||||
}
|
||||
|
||||
return &UploaderManager{handler: handler}, nil
|
||||
return &UploaderManager{
|
||||
active: sysConfig.OSS.Active,
|
||||
local: local,
|
||||
aliyun: aliyun,
|
||||
mini: mini,
|
||||
qiniu: qiniu,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *UploaderManager) GetUploadHandler() Uploader {
|
||||
return m.handler
|
||||
switch m.active {
|
||||
case Local:
|
||||
return m.local
|
||||
case AliYun:
|
||||
return m.aliyun
|
||||
case Minio:
|
||||
return m.mini
|
||||
case QiNiu:
|
||||
return m.qiniu
|
||||
}
|
||||
return m.local
|
||||
}
|
||||
|
||||
func (m *UploaderManager) UpdateConfig(config types.OSSConfig) {
|
||||
switch config.Active {
|
||||
case Local:
|
||||
m.local.UpdateConfig(config.Local)
|
||||
case AliYun:
|
||||
m.aliyun.UpdateConfig(config.AliYun)
|
||||
case Minio:
|
||||
m.mini.UpdateConfig(config.Minio)
|
||||
case QiNiu:
|
||||
m.qiniu.UpdateConfig(config.QiNiu)
|
||||
}
|
||||
m.active = config.Active
|
||||
}
|
||||
|
||||
@@ -12,129 +12,98 @@ import (
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"github.com/go-pay/gopay"
|
||||
"github.com/go-pay/gopay/alipay"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/go-pay/gopay"
|
||||
"github.com/go-pay/gopay/alipay"
|
||||
)
|
||||
|
||||
type AlipayService struct {
|
||||
config *types.AlipayConfig
|
||||
client *alipay.Client
|
||||
config *types.AlipayConfig
|
||||
}
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
|
||||
config := appConfig.AlipayConfig
|
||||
func NewAlipayService(sysConfig *types.SystemConfig) (*AlipayService, error) {
|
||||
config := sysConfig.Payment.Alipay
|
||||
if !config.Enabled {
|
||||
logger.Info("Disabled Alipay service")
|
||||
return nil, nil
|
||||
logger.Debug("Disabled Alipay service")
|
||||
}
|
||||
priKey, err := readKey(config.PrivateKey)
|
||||
|
||||
service := &AlipayService{config: &config}
|
||||
if config.Enabled {
|
||||
err := service.UpdateConfig(&config)
|
||||
if err != nil {
|
||||
logger.Errorf("支付宝服务初始化失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *AlipayService) UpdateConfig(config *types.AlipayConfig) error {
|
||||
client, err := alipay.NewClient(config.AppId, config.PrivateKey, !config.SandBox)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error with read App Private key: %v", err)
|
||||
return fmt.Errorf("error with initialize alipay service: %v", err)
|
||||
}
|
||||
|
||||
client, err := alipay.NewClient(config.AppId, priKey, !config.SandBox)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error with initialize alipay service: %v", err)
|
||||
s.client = client
|
||||
s.config = config
|
||||
if os.Getenv("GEEKAI_DEBUG") == "true" {
|
||||
logger.Info("Alipay Debug mode is enabled")
|
||||
client.DebugSwitch = gopay.DebugOn
|
||||
}
|
||||
|
||||
//client.DebugSwitch = gopay.DebugOn // 开启调试模式
|
||||
client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
|
||||
SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
|
||||
SetSignType(alipay.RSA2) // 设置签名类型,不设置默认 RSA2
|
||||
|
||||
if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
|
||||
return nil, fmt.Errorf("error with load payment public key: %v", err)
|
||||
}
|
||||
|
||||
return &AlipayService{config: &config, client: client}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
type AlipayParams struct {
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
Subject string `json:"subject"`
|
||||
TotalFee string `json:"total_fee"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
NotifyURL string `json:"notify_url"`
|
||||
}
|
||||
|
||||
func (s *AlipayService) PayMobile(params AlipayParams) (string, error) {
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("subject", params.Subject)
|
||||
bm.Set("out_trade_no", params.OutTradeNo)
|
||||
bm.Set("quit_url", params.ReturnURL)
|
||||
bm.Set("total_amount", params.TotalFee)
|
||||
bm.Set("product_code", "QUICK_WAP_WAY")
|
||||
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradeWapPay(context.Background(), bm)
|
||||
}
|
||||
|
||||
func (s *AlipayService) PayPC(params AlipayParams) (string, error) {
|
||||
func (s *AlipayService) Pay(params PayRequest) (string, error) {
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("subject", params.Subject)
|
||||
bm.Set("out_trade_no", params.OutTradeNo)
|
||||
bm.Set("total_amount", params.TotalFee)
|
||||
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
|
||||
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradePagePay(context.Background(), bm)
|
||||
return s.client.TradeWapPay(context.Background(), bm)
|
||||
}
|
||||
|
||||
func (s *AlipayService) Query(outTradeNo string) (OrderInfo, error) {
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("out_trade_no", outTradeNo)
|
||||
rsp, err := s.client.TradeQuery(context.Background(), bm)
|
||||
if err != nil {
|
||||
return OrderInfo{}, fmt.Errorf("error with trade query: %v", err)
|
||||
}
|
||||
|
||||
switch rsp.Response.TradeStatus {
|
||||
case "TRADE_SUCCESS":
|
||||
logger.Debugf("支付宝查询订单成功:%+v", rsp.Response)
|
||||
return OrderInfo{
|
||||
OutTradeNo: rsp.Response.OutTradeNo,
|
||||
TradeId: rsp.Response.TradeNo,
|
||||
Amount: rsp.Response.TotalAmount,
|
||||
Status: Success,
|
||||
PayTime: rsp.Response.SendPayDate,
|
||||
}, nil
|
||||
case "TRADE_CLOSED":
|
||||
return OrderInfo{Status: Closed}, nil
|
||||
default:
|
||||
return OrderInfo{}, fmt.Errorf("error with trade query: %v", rsp.Response.TradeStatus)
|
||||
}
|
||||
}
|
||||
|
||||
// TradeVerify 交易验证
|
||||
func (s *AlipayService) TradeVerify(request *http.Request) NotifyVo {
|
||||
func (s *AlipayService) TradeVerify(request *http.Request) (OrderInfo, error) {
|
||||
notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法
|
||||
if err != nil {
|
||||
return NotifyVo{
|
||||
Status: Failure,
|
||||
Message: "error with parse notify request: " + err.Error(),
|
||||
}
|
||||
return OrderInfo{}, fmt.Errorf("error with parse notify request: %v", err)
|
||||
}
|
||||
|
||||
_, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq)
|
||||
if err != nil {
|
||||
return NotifyVo{
|
||||
Status: Failure,
|
||||
Message: "error with verify sign: " + err.Error(),
|
||||
}
|
||||
return OrderInfo{}, fmt.Errorf("error with verify sign: %v", err)
|
||||
}
|
||||
|
||||
return s.TradeQuery(request.Form.Get("out_trade_no"))
|
||||
return s.Query(request.Form.Get("out_trade_no"))
|
||||
}
|
||||
|
||||
func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo {
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("out_trade_no", outTradeNo)
|
||||
|
||||
//查询订单
|
||||
rsp, err := s.client.TradeQuery(context.Background(), bm)
|
||||
if err != nil {
|
||||
return NotifyVo{
|
||||
Status: Failure,
|
||||
Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
if rsp.Response.TradeStatus == "TRADE_SUCCESS" {
|
||||
return NotifyVo{
|
||||
Status: Success,
|
||||
OutTradeNo: rsp.Response.OutTradeNo,
|
||||
TradeId: rsp.Response.TradeNo,
|
||||
Amount: rsp.Response.TotalAmount,
|
||||
Subject: rsp.Response.Subject,
|
||||
Message: "OK",
|
||||
}
|
||||
} else {
|
||||
return NotifyVo{
|
||||
Status: Failure,
|
||||
Message: "异步查询验证订单信息发生错误" + outTradeNo,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readKey(filename string) (string, error) {
|
||||
data, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
var _ PayService = (*AlipayService)(nil)
|
||||
|
||||
@@ -22,41 +22,30 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// GeekPayService Geek 支付服务
|
||||
type GeekPayService struct {
|
||||
config *types.GeekPayConfig
|
||||
// EPayService 易支付服务
|
||||
type EPayService struct {
|
||||
config *types.EpayConfig
|
||||
}
|
||||
|
||||
func NewJPayService(appConfig *types.AppConfig) *GeekPayService {
|
||||
return &GeekPayService{
|
||||
config: &appConfig.GeekPayConfig,
|
||||
func NewEPayService(sysConfig *types.SystemConfig) *EPayService {
|
||||
return &EPayService{
|
||||
config: &sysConfig.Payment.Epay,
|
||||
}
|
||||
}
|
||||
|
||||
type GeekPayParams struct {
|
||||
Method string `json:"method"` // 接口类型
|
||||
Device string `json:"device"` // 设备类型
|
||||
Type string `json:"type"` // 支付方式
|
||||
OutTradeNo string `json:"out_trade_no"` // 商户订单号
|
||||
Name string `json:"name"` // 商品名称
|
||||
Money string `json:"money"` // 商品金额
|
||||
ClientIP string `json:"clientip"` //用户IP地址
|
||||
SubOpenId string `json:"sub_openid"` // 微信用户 openid,仅小程序支付需要
|
||||
SubAppId string `json:"sub_appid"` // 小程序 AppId,仅小程序支付需要
|
||||
NotifyURL string `json:"notify_url"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
func (s *EPayService) UpdateConfig(config *types.EpayConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
// Pay 支付订单
|
||||
func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
|
||||
func (s *EPayService) Pay(params PayRequest) (string, error) {
|
||||
p := map[string]string{
|
||||
"pid": s.config.AppId,
|
||||
//"method": params.Method,
|
||||
"pid": s.config.AppId,
|
||||
"device": params.Device,
|
||||
"type": params.Type,
|
||||
"type": params.PayWay,
|
||||
"out_trade_no": params.OutTradeNo,
|
||||
"name": params.Name,
|
||||
"money": params.Money,
|
||||
"name": params.Subject,
|
||||
"money": params.TotalFee,
|
||||
"clientip": params.ClientIP,
|
||||
"notify_url": params.NotifyURL,
|
||||
"return_url": params.ReturnURL,
|
||||
@@ -64,10 +53,21 @@ func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
|
||||
}
|
||||
p["sign"] = s.Sign(p)
|
||||
p["sign_type"] = "MD5"
|
||||
return s.sendRequest(s.config.ApiURL, p)
|
||||
resp, err := s.sendRequest(s.config.ApiURL, p)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.Code != 1 {
|
||||
return "", errors.New(resp.Msg)
|
||||
}
|
||||
if resp.PayURL != "" {
|
||||
return resp.PayURL, nil
|
||||
} else {
|
||||
return resp.QrCode, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GeekPayService) Sign(params map[string]string) string {
|
||||
func (s *EPayService) Sign(params map[string]string) string {
|
||||
// 按字母顺序排序参数
|
||||
var keys []string
|
||||
for k := range params {
|
||||
@@ -100,7 +100,7 @@ type GeekPayResp struct {
|
||||
UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接
|
||||
}
|
||||
|
||||
func (s *GeekPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
|
||||
func (s *EPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
|
||||
form := url.Values{}
|
||||
for k, v := range params {
|
||||
form.Add(k, v)
|
||||
@@ -137,3 +137,61 @@ func (s *GeekPayService) sendRequest(endpoint string, params map[string]string)
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
func (s *EPayService) Query(outTradeNo string) (OrderInfo, error) {
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("act", "order")
|
||||
params.Set("pid", s.config.AppId)
|
||||
params.Set("key", s.config.PrivateKey)
|
||||
params.Set("out_trade_no", outTradeNo)
|
||||
|
||||
apiURL := fmt.Sprintf("%s/api.php?%s", s.config.ApiURL, params.Encode())
|
||||
|
||||
tr := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
}
|
||||
client := &http.Client{Transport: tr}
|
||||
resp, err := client.Get(apiURL)
|
||||
if err != nil {
|
||||
return OrderInfo{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return OrderInfo{}, err
|
||||
}
|
||||
logger.Debugf(string(body))
|
||||
|
||||
var result struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Status string `json:"status"`
|
||||
Name string `json:"name"`
|
||||
Money string `json:"money"`
|
||||
EndTime string `json:"endtime"`
|
||||
TradeNo string `json:"trade_no"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return OrderInfo{}, errors.New("订单查询响应解析失败")
|
||||
}
|
||||
if result.Code != 1 {
|
||||
return OrderInfo{}, errors.New(result.Msg)
|
||||
}
|
||||
logger.Debugf("订单信息:%+v", result)
|
||||
orderInfo := OrderInfo{
|
||||
OutTradeNo: outTradeNo,
|
||||
TradeId: result.TradeNo,
|
||||
Amount: result.Money,
|
||||
PayTime: result.EndTime,
|
||||
}
|
||||
if result.Status == "1" {
|
||||
orderInfo.Status = Success
|
||||
} else {
|
||||
orderInfo.Status = Failure
|
||||
}
|
||||
return orderInfo, nil
|
||||
}
|
||||
|
||||
var _ PayService = (*EPayService)(nil)
|
||||
@@ -1,171 +0,0 @@
|
||||
package payment
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type HuPiPayService struct {
|
||||
appId string
|
||||
appSecret string
|
||||
apiURL string
|
||||
}
|
||||
|
||||
func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
|
||||
return &HuPiPayService{
|
||||
appId: config.HuPiPayConfig.AppId,
|
||||
appSecret: config.HuPiPayConfig.AppSecret,
|
||||
apiURL: config.HuPiPayConfig.ApiURL,
|
||||
}
|
||||
}
|
||||
|
||||
type HuPiPayParams struct {
|
||||
AppId string `json:"appid"`
|
||||
Version string `json:"version"`
|
||||
TradeOrderId string `json:"trade_order_id"`
|
||||
TotalFee string `json:"total_fee"`
|
||||
Title string `json:"title"`
|
||||
NotifyURL string `json:"notify_url"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
WapName string `json:"wap_name"`
|
||||
CallbackURL string `json:"callback_url"`
|
||||
Time string `json:"time"`
|
||||
NonceStr string `json:"nonce_str"`
|
||||
Type string `json:"type"`
|
||||
WapUrl string `json:"wap_url"`
|
||||
}
|
||||
|
||||
type HuPiPayResp struct {
|
||||
Openid interface{} `json:"openid"`
|
||||
UrlQrcode string `json:"url_qrcode"`
|
||||
URL string `json:"url"`
|
||||
ErrCode int `json:"errcode"`
|
||||
ErrMsg string `json:"errmsg,omitempty"`
|
||||
}
|
||||
|
||||
// Pay 执行支付请求操作
|
||||
func (s *HuPiPayService) Pay(params HuPiPayParams) (HuPiPayResp, error) {
|
||||
data := url.Values{}
|
||||
simple := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
params.AppId = s.appId
|
||||
params.Time = simple
|
||||
params.NonceStr = simple
|
||||
encode := utils.JsonEncode(params)
|
||||
m := make(map[string]string)
|
||||
_ = utils.JsonDecode(encode, &m)
|
||||
for k, v := range m {
|
||||
data.Add(k, fmt.Sprintf("%v", v))
|
||||
}
|
||||
// 生成签名
|
||||
data.Add("hash", s.Sign(data))
|
||||
// 发送支付请求
|
||||
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
|
||||
resp, err := http.PostForm(apiURL, data)
|
||||
if err != nil {
|
||||
return HuPiPayResp{}, fmt.Errorf("error with requst api: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
all, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return HuPiPayResp{}, fmt.Errorf("error with reading response: %v", err)
|
||||
}
|
||||
|
||||
var res HuPiPayResp
|
||||
err = utils.JsonDecode(string(all), &res)
|
||||
if err != nil {
|
||||
return HuPiPayResp{}, fmt.Errorf("error with decode payment result: %v", err)
|
||||
}
|
||||
|
||||
if res.ErrCode != 0 {
|
||||
return HuPiPayResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// Sign 签名方法
|
||||
func (s *HuPiPayService) Sign(params url.Values) string {
|
||||
params.Del(`Sign`)
|
||||
var keys = make([]string, 0, 0)
|
||||
for key := range params {
|
||||
if params.Get(key) != `` {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
var pList = make([]string, 0, 0)
|
||||
for _, key := range keys {
|
||||
var value = strings.TrimSpace(params.Get(key))
|
||||
if len(value) > 0 {
|
||||
pList = append(pList, key+"="+value)
|
||||
}
|
||||
}
|
||||
var src = strings.Join(pList, "&")
|
||||
src += s.appSecret
|
||||
|
||||
md5bs := md5.Sum([]byte(src))
|
||||
return hex.EncodeToString(md5bs[:])
|
||||
}
|
||||
|
||||
// Check 校验订单状态
|
||||
func (s *HuPiPayService) Check(outTradeNo string) error {
|
||||
data := url.Values{}
|
||||
data.Add("appid", s.appId)
|
||||
data.Add("out_trade_order", outTradeNo)
|
||||
stamp := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
data.Add("time", stamp)
|
||||
data.Add("nonce_str", stamp)
|
||||
data.Add("hash", s.Sign(data))
|
||||
|
||||
apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL)
|
||||
resp, err := http.PostForm(apiURL, data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with http reqeust: %v", err)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with reading response: %v", err)
|
||||
}
|
||||
|
||||
var r struct {
|
||||
ErrCode int `json:"errcode"`
|
||||
Data struct {
|
||||
Status string `json:"status"`
|
||||
OpenOrderId string `json:"open_order_id"`
|
||||
} `json:"data,omitempty"`
|
||||
ErrMsg string `json:"errmsg"`
|
||||
Hash string `json:"hash"`
|
||||
}
|
||||
err = utils.JsonDecode(string(body), &r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with decode response: %v", err)
|
||||
}
|
||||
|
||||
if r.ErrCode == 0 && r.Data.Status == "OD" {
|
||||
return nil
|
||||
} else {
|
||||
logger.Debugf("%+v", r)
|
||||
return errors.New("order not paid:" + r.ErrMsg)
|
||||
}
|
||||
}
|
||||
54
api/service/payment/pay_service.go
Normal file
54
api/service/payment/pay_service.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package payment
|
||||
|
||||
// 支付渠道定义
|
||||
const PayChannelAL = "alipay" // 支付宝
|
||||
const PayChannelWX = "wxpay" // 微信支付
|
||||
const PayChannelEpay = "epay" // 易支付
|
||||
|
||||
// 支付方式
|
||||
const PayWayAL = "alipay"
|
||||
const PayWayWX = "wxpay"
|
||||
|
||||
const (
|
||||
Success = 0
|
||||
Failure = 1
|
||||
Closed = 2
|
||||
)
|
||||
|
||||
type PayRequest struct {
|
||||
OutTradeNo string // 商户订单号
|
||||
Subject string // 商品名称
|
||||
TotalFee string // 商品金额
|
||||
ReturnURL string // 回调地址
|
||||
NotifyURL string // 回调地址
|
||||
|
||||
// 易支付专有参数
|
||||
Method string // 接口类型
|
||||
Device string // 设备类型
|
||||
PayWay string // 支付方式
|
||||
ClientIP string //用户IP地址
|
||||
OpenID string // 用户openid
|
||||
|
||||
}
|
||||
|
||||
type OrderInfo struct {
|
||||
Mchid string // 商户号
|
||||
OutTradeNo string // 商户订单号
|
||||
TradeId string // 交易号
|
||||
Amount string // 金额
|
||||
Status int // 状态 0: 未支付 1: 已支付 2: 已关闭
|
||||
PayTime string // 完成支付时间
|
||||
}
|
||||
|
||||
func (o OrderInfo) Closed() bool {
|
||||
return o.Status == Closed
|
||||
}
|
||||
|
||||
func (o OrderInfo) Success() bool {
|
||||
return o.Status == Success
|
||||
}
|
||||
|
||||
type PayService interface {
|
||||
Pay(params PayRequest) (string, error) // 生成支付链接
|
||||
Query(outTradeNo string) (OrderInfo, error) // 查询订单
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package payment
|
||||
|
||||
type NotifyVo struct {
|
||||
Status int
|
||||
OutTradeNo string // 商户订单号
|
||||
TradeId string // 交易ID
|
||||
Amount string // 交易金额
|
||||
Message string
|
||||
Subject string
|
||||
}
|
||||
|
||||
func (v NotifyVo) Success() bool {
|
||||
return v.Status == Success
|
||||
}
|
||||
|
||||
const (
|
||||
Success = 0
|
||||
Failure = 1
|
||||
)
|
||||
@@ -1,144 +0,0 @@
|
||||
package payment
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"github.com/go-pay/gopay"
|
||||
"github.com/go-pay/gopay/wechat/v3"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type WechatPayService struct {
|
||||
config *types.WechatPayConfig
|
||||
client *wechat.ClientV3
|
||||
}
|
||||
|
||||
func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
|
||||
config := appConfig.WechatPayConfig
|
||||
if !config.Enabled {
|
||||
logger.Info("Disabled WechatPay service")
|
||||
return nil, nil
|
||||
}
|
||||
priKey, err := readKey(config.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error with read App Private key: %v", err)
|
||||
}
|
||||
|
||||
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, priKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error with initialize WechatPay service: %v", err)
|
||||
}
|
||||
err = client.AutoVerifySign()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error with autoVerifySign: %v", err)
|
||||
}
|
||||
//client.DebugSwitch = gopay.DebugOn
|
||||
|
||||
return &WechatPayService{config: &config, client: client}, nil
|
||||
}
|
||||
|
||||
type WechatPayParams struct {
|
||||
OutTradeNo string `json:"out_trade_no"`
|
||||
TotalFee int `json:"total_fee"`
|
||||
Subject string `json:"subject"`
|
||||
ClientIP string `json:"client_ip"`
|
||||
ReturnURL string `json:"return_url"`
|
||||
NotifyURL string `json:"notify_url"`
|
||||
}
|
||||
|
||||
func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
|
||||
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||
// 初始化 BodyMap
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("appid", s.config.AppId).
|
||||
Set("mchid", s.config.MchId).
|
||||
Set("description", params.Subject).
|
||||
Set("out_trade_no", params.OutTradeNo).
|
||||
Set("time_expire", expire).
|
||||
Set("notify_url", params.NotifyURL).
|
||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
bm.Set("total", params.TotalFee).
|
||||
Set("currency", "CNY")
|
||||
})
|
||||
|
||||
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
|
||||
}
|
||||
if wxRsp.Code != wechat.Success {
|
||||
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
|
||||
}
|
||||
return wxRsp.Response.CodeUrl, nil
|
||||
}
|
||||
|
||||
func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
|
||||
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||
// 初始化 BodyMap
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("appid", s.config.AppId).
|
||||
Set("mchid", s.config.MchId).
|
||||
Set("description", params.Subject).
|
||||
Set("out_trade_no", params.OutTradeNo).
|
||||
Set("time_expire", expire).
|
||||
Set("notify_url", params.NotifyURL).
|
||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
bm.Set("total", params.TotalFee).
|
||||
Set("currency", "CNY")
|
||||
}).
|
||||
SetBodyMap("scene_info", func(bm gopay.BodyMap) {
|
||||
bm.Set("payer_client_ip", params.ClientIP).
|
||||
SetBodyMap("h5_info", func(bm gopay.BodyMap) {
|
||||
bm.Set("type", "Wap")
|
||||
})
|
||||
})
|
||||
|
||||
wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
|
||||
}
|
||||
if wxRsp.Code != wechat.Success {
|
||||
return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
|
||||
}
|
||||
return wxRsp.Response.H5Url, nil
|
||||
}
|
||||
|
||||
type NotifyResponse struct {
|
||||
Code string `json:"code"`
|
||||
Message string `xml:"message"`
|
||||
}
|
||||
|
||||
// TradeVerify 交易验证
|
||||
func (s *WechatPayService) TradeVerify(request *http.Request) NotifyVo {
|
||||
notifyReq, err := wechat.V3ParseNotify(request)
|
||||
if err != nil {
|
||||
return NotifyVo{Status: 1, Message: fmt.Sprintf("error with client v3 parse notify: %v", err)}
|
||||
}
|
||||
|
||||
// TODO: 这里验签程序有 Bug,一直报错:crypto/rsa: verification error,先暂时取消验签
|
||||
//err = notifyReq.VerifySignByPK(s.client.WxPublicKey())
|
||||
//if err != nil {
|
||||
// return fmt.Errorf("error with client v3 verify sign: %v", err)
|
||||
//}
|
||||
|
||||
// 解密支付密文,验证订单信息
|
||||
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
|
||||
if err != nil {
|
||||
return NotifyVo{Status: Failure, Message: fmt.Sprintf("error with client v3 decrypt: %v", err)}
|
||||
}
|
||||
|
||||
return NotifyVo{
|
||||
Status: Success,
|
||||
OutTradeNo: result.OutTradeNo,
|
||||
TradeId: result.TransactionId,
|
||||
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
|
||||
}
|
||||
}
|
||||
217
api/service/payment/wxpay_service.go
Normal file
217
api/service/payment/wxpay_service.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package payment
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/go-pay/gopay"
|
||||
"github.com/go-pay/gopay/wechat/v3"
|
||||
)
|
||||
|
||||
type WxPayService struct {
|
||||
config *types.WxPayConfig
|
||||
client *wechat.ClientV3
|
||||
}
|
||||
|
||||
func NewWxpayService(sysConfig *types.SystemConfig) (*WxPayService, error) {
|
||||
config := sysConfig.Payment.WxPay
|
||||
if !config.Enabled {
|
||||
logger.Debug("Disabled WechatPay service")
|
||||
}
|
||||
|
||||
service := &WxPayService{config: &config}
|
||||
if config.Enabled {
|
||||
err := service.UpdateConfig(&config)
|
||||
if err != nil {
|
||||
logger.Errorf("微信支付服务初始化失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *WxPayService) UpdateConfig(config *types.WxPayConfig) error {
|
||||
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, config.PrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with initialize WechatPay service: %v", err)
|
||||
}
|
||||
err = client.AutoVerifySign()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error with autoVerifySign: %v", err)
|
||||
}
|
||||
s.client = client
|
||||
if os.Getenv("GEEKAI_DEBUG") == "true" {
|
||||
logger.Info("WechatPay Debug mode is enabled")
|
||||
client.DebugSwitch = gopay.DebugOn
|
||||
}
|
||||
s.config = config
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *WxPayService) Pay(params PayRequest) (string, error) {
|
||||
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||
// 初始化 BodyMap
|
||||
bm := make(gopay.BodyMap)
|
||||
bm.Set("appid", s.config.AppId).
|
||||
Set("mchid", s.config.MchId).
|
||||
Set("description", params.Subject).
|
||||
Set("out_trade_no", params.OutTradeNo).
|
||||
Set("time_expire", expire).
|
||||
Set("notify_url", params.NotifyURL).
|
||||
SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
bm.Set("total", utils.IntValue(params.TotalFee, 0)).
|
||||
Set("currency", "CNY")
|
||||
})
|
||||
logger.Debugf("wxpay params: %+v", bm)
|
||||
if params.Device == "mobile" {
|
||||
bm.SetBodyMap("scene_info", func(bm gopay.BodyMap) {
|
||||
bm.Set("payer_client_ip", params.ClientIP)
|
||||
}).SetBodyMap("payer", func(bm gopay.BodyMap) {
|
||||
bm.Set("openid", params.OpenID)
|
||||
})
|
||||
wxRsp, err := s.client.V3TransactionJsapi(context.Background(), bm)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with client v3 transaction Jsapi: %v", err)
|
||||
}
|
||||
if wxRsp.Code != wechat.Success {
|
||||
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
|
||||
}
|
||||
return wxRsp.Response.PrepayId, nil
|
||||
} else if params.Device == "pc" {
|
||||
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
|
||||
}
|
||||
if wxRsp.Code != wechat.Success {
|
||||
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
|
||||
}
|
||||
return wxRsp.Response.CodeUrl, nil
|
||||
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (s *WxPayService) Query(outTradeNo string) (OrderInfo, error) {
|
||||
wxRsp, err := s.client.V3TransactionQueryOrder(context.Background(), wechat.OutTradeNo, outTradeNo)
|
||||
if err != nil {
|
||||
return OrderInfo{}, fmt.Errorf("error with client v3 transaction query: %v", err)
|
||||
}
|
||||
|
||||
if wxRsp.Code != wechat.Success {
|
||||
return OrderInfo{}, fmt.Errorf("error status with querying order: %v", wxRsp.Error)
|
||||
}
|
||||
|
||||
if wxRsp.Response.TradeState == "CLOSED" {
|
||||
return OrderInfo{Status: Closed}, nil
|
||||
}
|
||||
|
||||
orderInfo := OrderInfo{
|
||||
OutTradeNo: wxRsp.Response.OutTradeNo,
|
||||
TradeId: wxRsp.Response.TransactionId,
|
||||
Amount: fmt.Sprintf("%d", wxRsp.Response.Amount.Total/100),
|
||||
PayTime: wxRsp.Response.SuccessTime,
|
||||
}
|
||||
if wxRsp.Response.TradeState == "SUCCESS" {
|
||||
orderInfo.Status = Success
|
||||
} else {
|
||||
orderInfo.Status = Failure
|
||||
}
|
||||
return orderInfo, nil
|
||||
}
|
||||
|
||||
// TradeVerify 交易验证
|
||||
func (s *WxPayService) TradeVerify(request *http.Request) (OrderInfo, error) {
|
||||
notifyReq, err := wechat.V3ParseNotify(request)
|
||||
if err != nil {
|
||||
return OrderInfo{}, fmt.Errorf("error with client v3 parse notify: %v", err)
|
||||
}
|
||||
|
||||
// 解密支付密文,验证订单信息
|
||||
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
|
||||
if err != nil {
|
||||
return OrderInfo{}, fmt.Errorf("error with client v3 decrypt: %v", err)
|
||||
}
|
||||
|
||||
return OrderInfo{
|
||||
Status: Success,
|
||||
OutTradeNo: result.OutTradeNo,
|
||||
TradeId: result.TransactionId,
|
||||
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
|
||||
PayTime: result.SuccessTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
|
||||
// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||
// // 初始化 BodyMap
|
||||
// bm := make(gopay.BodyMap)
|
||||
// bm.Set("appid", s.config.AppId).
|
||||
// Set("mchid", s.config.MchId).
|
||||
// Set("description", params.Subject).
|
||||
// Set("out_trade_no", params.OutTradeNo).
|
||||
// Set("time_expire", expire).
|
||||
// Set("notify_url", params.NotifyURL).
|
||||
// SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
// bm.Set("total", params.TotalFee).
|
||||
// Set("currency", "CNY")
|
||||
// })
|
||||
|
||||
// wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
|
||||
// if err != nil {
|
||||
// return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
|
||||
// }
|
||||
// if wxRsp.Code != wechat.Success {
|
||||
// return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
|
||||
// }
|
||||
// return wxRsp.Response.CodeUrl, nil
|
||||
// }
|
||||
|
||||
// func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
|
||||
// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
|
||||
// // 初始化 BodyMap
|
||||
// bm := make(gopay.BodyMap)
|
||||
// bm.Set("appid", s.config.AppId).
|
||||
// Set("mchid", s.config.MchId).
|
||||
// Set("description", params.Subject).
|
||||
// Set("out_trade_no", params.OutTradeNo).
|
||||
// Set("time_expire", expire).
|
||||
// Set("notify_url", params.NotifyURL).
|
||||
// SetBodyMap("amount", func(bm gopay.BodyMap) {
|
||||
// bm.Set("total", params.TotalFee).
|
||||
// Set("currency", "CNY")
|
||||
// }).
|
||||
// SetBodyMap("scene_info", func(bm gopay.BodyMap) {
|
||||
// bm.Set("payer_client_ip", params.ClientIP).
|
||||
// SetBodyMap("h5_info", func(bm gopay.BodyMap) {
|
||||
// bm.Set("type", "Wap")
|
||||
// })
|
||||
// })
|
||||
|
||||
// wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
|
||||
// if err != nil {
|
||||
// return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
|
||||
// }
|
||||
// if wxRsp.Code != wechat.Success {
|
||||
// return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
|
||||
// }
|
||||
// return wxRsp.Response.H5Url, nil
|
||||
// }
|
||||
|
||||
// type NotifyResponse struct {
|
||||
// Code string `json:"code"`
|
||||
// Message string `xml:"message"`
|
||||
// }
|
||||
|
||||
var _ PayService = (*WxPayService)(nil)
|
||||
@@ -10,36 +10,57 @@ package sms
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
|
||||
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
|
||||
)
|
||||
|
||||
type AliYunSmsService struct {
|
||||
config *types.SmsConfigAli
|
||||
config types.SmsConfigAli
|
||||
client *dysmsapi.Client
|
||||
domain string
|
||||
zoneId string
|
||||
}
|
||||
|
||||
func NewAliYunSmsService(appConfig *types.AppConfig) (*AliYunSmsService, error) {
|
||||
config := &appConfig.SMS.Ali
|
||||
// 创建阿里云短信客户端
|
||||
func NewAliYunSmsService(sysConfig *types.SystemConfig) (*AliYunSmsService, error) {
|
||||
config := sysConfig.SMS.Ali
|
||||
domain := "dysmsapi.aliyuncs.com"
|
||||
zoneId := "cn-hangzhou"
|
||||
|
||||
s := AliYunSmsService{
|
||||
config: config,
|
||||
domain: domain,
|
||||
zoneId: zoneId,
|
||||
}
|
||||
if sysConfig.SMS.Active == Ali {
|
||||
err := s.UpdateConfig(config)
|
||||
if err != nil {
|
||||
logger.Errorf("阿里云短信初始化失败: %v", err)
|
||||
}
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
func (s *AliYunSmsService) UpdateConfig(config types.SmsConfigAli) error {
|
||||
client, err := dysmsapi.NewClientWithAccessKey(
|
||||
"cn-hangzhou",
|
||||
s.zoneId,
|
||||
config.AccessKey,
|
||||
config.AccessSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create client: %v", err)
|
||||
return fmt.Errorf("failed to create client: %v", err)
|
||||
}
|
||||
|
||||
return &AliYunSmsService{
|
||||
config: config,
|
||||
client: client,
|
||||
}, nil
|
||||
s.client = client
|
||||
s.config = config
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AliYunSmsService) SendVerifyCode(mobile string, code int) error {
|
||||
if s.client == nil {
|
||||
return fmt.Errorf("阿里云短信服务未初始化")
|
||||
}
|
||||
// 创建短信请求并设置参数
|
||||
request := dysmsapi.CreateSendSmsRequest()
|
||||
request.Scheme = "https"
|
||||
request.Domain = s.config.Domain
|
||||
request.Domain = s.domain
|
||||
request.PhoneNumbers = mobile
|
||||
request.SignName = s.config.Sign
|
||||
request.TemplateCode = s.config.CodeTempId
|
||||
|
||||
@@ -19,20 +19,21 @@ import (
|
||||
)
|
||||
|
||||
type BaoSmsService struct {
|
||||
config *types.SmsConfigBao
|
||||
config types.SmsConfigBao
|
||||
domain string
|
||||
}
|
||||
|
||||
func NewSmsBaoSmsService(appConfig *types.AppConfig) *BaoSmsService {
|
||||
config := appConfig.SMS.Bao
|
||||
if config.Domain == "" { // use default domain
|
||||
config.Domain = "api.smsbao.com"
|
||||
logger.Infof("Using default domain for SMS-BAO: %s", config.Domain)
|
||||
}
|
||||
func NewBaoSmsService(sysConfig *types.SystemConfig) *BaoSmsService {
|
||||
return &BaoSmsService{
|
||||
config: &config,
|
||||
config: sysConfig.SMS.Bao,
|
||||
domain: "api.smsbao.com",
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BaoSmsService) UpdateConfig(config types.SmsConfigBao) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
var errMsg = map[string]string{
|
||||
"0": "短信发送成功",
|
||||
"-1": "参数不全",
|
||||
@@ -56,7 +57,7 @@ func (s *BaoSmsService) SendVerifyCode(mobile string, code int) error {
|
||||
params.Set("m", mobile)
|
||||
params.Set("c", content)
|
||||
|
||||
apiURL := fmt.Sprintf("https://%s/sms?%s", s.config.Domain, params.Encode())
|
||||
apiURL := fmt.Sprintf("https://%s/sms?%s", s.domain, params.Encode())
|
||||
response, err := http.Get(apiURL)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -7,8 +7,8 @@ package sms
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
const Ali = "ALI"
|
||||
const Bao = "BAO"
|
||||
const Ali = "aliyun"
|
||||
const Bao = "bao"
|
||||
|
||||
type Service interface {
|
||||
SendVerifyCode(mobile string, code int) error
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
package sms
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ServiceManager struct {
|
||||
handler Service
|
||||
}
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
func NewSendServiceManager(config *types.AppConfig) (*ServiceManager, error) {
|
||||
active := Ali
|
||||
if config.SMS.Active != "" {
|
||||
active = strings.ToUpper(config.SMS.Active)
|
||||
}
|
||||
var handler Service
|
||||
switch active {
|
||||
case Ali:
|
||||
client, err := NewAliYunSmsService(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handler = client
|
||||
break
|
||||
case Bao:
|
||||
handler = NewSmsBaoSmsService(config)
|
||||
break
|
||||
}
|
||||
|
||||
return &ServiceManager{handler: handler}, nil
|
||||
}
|
||||
|
||||
func (m *ServiceManager) GetService() Service {
|
||||
return m.handler
|
||||
}
|
||||
54
api/service/sms/sms_manager.go
Normal file
54
api/service/sms/sms_manager.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package sms
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
)
|
||||
|
||||
type SmsManager struct {
|
||||
aliyun *AliYunSmsService
|
||||
bao *BaoSmsService
|
||||
active string
|
||||
}
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
func NewSmsManager(sysConfig *types.SystemConfig, aliyun *AliYunSmsService, bao *BaoSmsService) (*SmsManager, error) {
|
||||
|
||||
return &SmsManager{
|
||||
active: sysConfig.SMS.Active,
|
||||
aliyun: aliyun,
|
||||
bao: bao,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *SmsManager) GetService() Service {
|
||||
switch m.active {
|
||||
case Ali:
|
||||
return m.aliyun
|
||||
case Bao:
|
||||
return m.bao
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *SmsManager) SetActive(active string) {
|
||||
m.active = active
|
||||
}
|
||||
|
||||
func (m *SmsManager) UpdateConfig(config types.SMSConfig) {
|
||||
switch config.Active {
|
||||
case Ali:
|
||||
m.aliyun.UpdateConfig(config.Ali)
|
||||
case Bao:
|
||||
m.bao.UpdateConfig(config.Bao)
|
||||
}
|
||||
m.active = config.Active
|
||||
}
|
||||
@@ -27,6 +27,10 @@ func NewSmtpService(appConfig *types.AppConfig) *SmtpService {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SmtpService) UpdateConfig(config *types.SmtpConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s *SmtpService) SendVerifyCode(to string, code int) error {
|
||||
subject := fmt.Sprintf("%s 注册验证码", s.config.AppName)
|
||||
body := fmt.Sprintf("【%s】:您的验证码为 %d,请不要告诉他人。如非本人操作,请忽略此邮件。", s.config.AppName, code)
|
||||
|
||||
@@ -112,6 +112,10 @@ type RespVo struct {
|
||||
Message string `json:"message"`
|
||||
Data string `json:"data"`
|
||||
Channel string `json:"channel,omitempty"`
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Service) Create(task types.SunoTask) (RespVo, error) {
|
||||
@@ -126,7 +130,7 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
|
||||
return RespVo{}, errors.New("no available API KEY for Suno")
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
reqBody := map[string]any{
|
||||
"task_id": task.RefTaskId,
|
||||
"continue_clip_id": task.RefSongId,
|
||||
"continue_at": task.ExtendSecs,
|
||||
@@ -154,13 +158,14 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
|
||||
}
|
||||
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
logger.Debugf("API response: %s", string(body))
|
||||
err = json.Unmarshal(body, &res)
|
||||
if err != nil {
|
||||
return RespVo{}, fmt.Errorf("解析API数据失败:%v, %s", err, string(body))
|
||||
}
|
||||
|
||||
if res.Code != "success" {
|
||||
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
|
||||
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Error.Message)
|
||||
}
|
||||
// update the last_use_at for api key
|
||||
apiKey.LastUsedAt = time.Now().Unix()
|
||||
@@ -225,7 +230,7 @@ func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
|
||||
return RespVo{}, errors.New("no available API KEY for Suno")
|
||||
}
|
||||
|
||||
reqBody := map[string]interface{}{
|
||||
reqBody := map[string]any{
|
||||
"url": task.AudioURL,
|
||||
}
|
||||
|
||||
@@ -330,7 +335,13 @@ func (s *Service) SyncTaskProgress() {
|
||||
job.SongId = v.Id
|
||||
job.Duration = int(v.Metadata.Duration)
|
||||
job.Prompt = v.Metadata.Prompt
|
||||
job.Tags = v.Metadata.Tags
|
||||
// 修复 tags 字段过长导致插入数据库失败
|
||||
if len(v.Metadata.Tags) > 255 {
|
||||
job.Tags = v.Metadata.Tags[:255]
|
||||
} else {
|
||||
job.Tags = v.Metadata.Tags
|
||||
}
|
||||
|
||||
job.ModelName = v.ModelName
|
||||
job.RawData = utils.JsonEncode(v)
|
||||
job.CoverURL = v.ImageLargeUrl
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package service
|
||||
|
||||
import logger2 "geekai/logger"
|
||||
|
||||
const FailTaskProgress = 101
|
||||
const (
|
||||
TaskStatusRunning = "RUNNING"
|
||||
@@ -15,6 +17,8 @@ type NotifyMessage struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
|
||||
|
||||
const ImagePromptOptimizeTemplate = `
|
||||
|
||||
117
api/service/wxlogin_service.go
Normal file
117
api/service/wxlogin_service.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package service
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"geekai/core/types"
|
||||
"geekai/utils"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type WxLoginService struct {
|
||||
config types.WxLoginConfig
|
||||
client *req.Client
|
||||
redisClient *redis.Client
|
||||
}
|
||||
|
||||
const loginStateKeyPrefix = "wx_login_state/"
|
||||
|
||||
type LoginStatus struct {
|
||||
Status string `json:"status"`
|
||||
OpenID string `json:"openid"`
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
const (
|
||||
LoginStatusPending = "pending"
|
||||
LoginStatusSuccess = "success"
|
||||
LoginStatusExpired = "expired" // 登录失效,需要重新登录
|
||||
)
|
||||
|
||||
func NewWxLoginService(config types.WxLoginConfig, redisClient *redis.Client) *WxLoginService {
|
||||
return &WxLoginService{
|
||||
config: config,
|
||||
client: req.C().SetTimeout(10 * time.Second),
|
||||
redisClient: redisClient,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WxLoginService) UpdateConfig(config types.WxLoginConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s *WxLoginService) GetConfig() types.WxLoginConfig {
|
||||
return s.config
|
||||
}
|
||||
|
||||
func (s *WxLoginService) SetConfig(config types.WxLoginConfig) {
|
||||
s.config = config
|
||||
}
|
||||
|
||||
func (s *WxLoginService) GetLoginQrCodeUrl(state string) (string, error) {
|
||||
if s.config.ApiKey == "" {
|
||||
return "", errors.New("无效的 API Key")
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/api/auth/wechat/login", types.GeekAPIURL)
|
||||
var res struct {
|
||||
Code types.BizCode `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data struct {
|
||||
Ticket string `json:"ticket"`
|
||||
Url string `json:"url"`
|
||||
} `json:"data"`
|
||||
}
|
||||
r, err := s.client.R().
|
||||
SetHeader("Authorization", s.config.ApiKey).
|
||||
SetBody(map[string]string{
|
||||
"notify_url": s.config.NotifyURL,
|
||||
"state": state,
|
||||
}).
|
||||
SetSuccessResult(&res).Post(url)
|
||||
if err != nil || r.IsErrorState() {
|
||||
return "", fmt.Errorf("请求 API 失败:%v", err)
|
||||
}
|
||||
|
||||
if res.Code != types.Success {
|
||||
return "", fmt.Errorf("请求 API 失败:%s", res.Message)
|
||||
}
|
||||
|
||||
status := LoginStatus{
|
||||
Status: LoginStatusPending,
|
||||
OpenID: "",
|
||||
}
|
||||
s.redisClient.Set(context.Background(), loginStateKeyPrefix+state, utils.JsonEncode(status), time.Hour)
|
||||
|
||||
return res.Data.Url, nil
|
||||
}
|
||||
|
||||
func (s *WxLoginService) GetLoginStatus(state string) (*LoginStatus, error) {
|
||||
result, err := s.redisClient.Get(context.Background(), loginStateKeyPrefix+state).Result()
|
||||
if err != nil {
|
||||
return nil, errors.New("登录失败")
|
||||
}
|
||||
|
||||
var status LoginStatus
|
||||
err = utils.JsonDecode(result, &status)
|
||||
if err != nil {
|
||||
return nil, errors.New("登录失败")
|
||||
}
|
||||
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
func (s *WxLoginService) SetLoginStatus(state string, status LoginStatus) {
|
||||
s.redisClient.Set(context.Background(), loginStateKeyPrefix+state, utils.JsonEncode(status), time.Hour)
|
||||
}
|
||||
@@ -1,64 +0,0 @@
|
||||
package service
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||
// * Use of this source code is governed by a Apache-2.0 license
|
||||
// * that can be found in the LICENSE file.
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
import (
|
||||
"context"
|
||||
"geekai/core/types"
|
||||
logger2 "geekai/logger"
|
||||
|
||||
"github.com/xxl-job/xxl-job-executor-go"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type XXLJobExecutor struct {
|
||||
executor xxl.Executor
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewXXLJobExecutor(config *types.AppConfig, db *gorm.DB) *XXLJobExecutor {
|
||||
if !config.XXLConfig.Enabled {
|
||||
logger.Info("XXL-JOB service is disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
exec := xxl.NewExecutor(
|
||||
xxl.ServerAddr(config.XXLConfig.ServerAddr),
|
||||
xxl.AccessToken(config.XXLConfig.AccessToken), //请求令牌(默认为空)
|
||||
xxl.ExecutorIp(config.XXLConfig.ExecutorIp), //可自动获取
|
||||
xxl.ExecutorPort(config.XXLConfig.ExecutorPort), //默认9999(非必填)
|
||||
xxl.RegistryKey(config.XXLConfig.RegistryKey), //执行器名称
|
||||
xxl.SetLogger(&customLogger{}), //自定义日志
|
||||
)
|
||||
exec.Init()
|
||||
return &XXLJobExecutor{executor: exec, db: db}
|
||||
}
|
||||
|
||||
func (e *XXLJobExecutor) Run() error {
|
||||
e.executor.RegTask("ClearOrders", e.ClearOrders)
|
||||
return e.executor.Run()
|
||||
}
|
||||
|
||||
// ClearOrders 清理未支付的订单,如果没有抛出异常则表示执行成功
|
||||
func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (msg string) {
|
||||
logger.Info("执行清理未支付订单...")
|
||||
|
||||
return "success"
|
||||
}
|
||||
|
||||
type customLogger struct{}
|
||||
|
||||
func (l *customLogger) Info(format string, a ...interface{}) {
|
||||
logger.Debugf(format, a...)
|
||||
}
|
||||
|
||||
func (l *customLogger) Error(format string, a ...interface{}) {
|
||||
logger.Errorf(format, a...)
|
||||
}
|
||||
Reference in New Issue
Block a user