From c0d17e132dc2a31456d89a42b269362847edbd3c Mon Sep 17 00:00:00 2001 From: MHSanaei Date: Thu, 2 Jul 2026 16:39:31 +0200 Subject: [PATCH] fix(job): batch ip-limit per-email lookups and persistence MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit processObserved paid four round-trips per observed email every 10s scan: an inbound-resolving join, a tracking-row read, an autocommit Save (one fsync each under synchronous=FULL), and — worst of all — a full JSON parse of the owning inbound's settings blob just to read that one client's limitIp. On a big single inbound that parse alone made a scan cost ~1.5s per online client. The scan now front-loads three chunked batch queries (clients.limit_ip, email->inbound through the client_inbounds relation keeping the lowest inbound id like the old First(), and the tracking rows) and writes every inbound_client_ips change inside one transaction, so M observed emails cost a handful of queries and a single fsync. The per-email LIKE fallback remains for emails missing from the relation, preserving the #4963 stale-email cleanup. limitIp now comes from the clients table (same source B3 gates on) instead of the settings blob, and xray disconnects for banned clients run after the commit so their network round-trips never extend the write transaction node syncs contend with. --- internal/web/job/check_client_ip_job.go | 306 +++++++++++++----- .../check_client_ip_job_integration_test.go | 12 +- 2 files changed, 229 insertions(+), 89 deletions(-) diff --git a/internal/web/job/check_client_ip_job.go b/internal/web/job/check_client_ip_job.go index 5faa76e35..e016f64f7 100644 --- a/internal/web/job/check_client_ip_job.go +++ b/internal/web/job/check_client_ip_job.go @@ -124,32 +124,187 @@ func (j *CheckClientIpJob) hasLimitIp() bool { return err == nil && probe > 0 } +const ipScanChunk = 400 + +func chunkEmails(s []string, size int) [][]string { + if len(s) == 0 { + return nil + } + chunks := make([][]string, 0, (len(s)+size-1)/size) + for size < len(s) { + s, chunks = s[size:], append(chunks, s[:size]) + } + return append(chunks, s) +} + +// loadClientLimits maps each observed email to its clients.limit_ip in a few +// chunked queries, replacing the per-email settings-JSON parse that previously +// resolved the limit. +func (j *CheckClientIpJob) loadClientLimits(emails []string) map[string]int { + db := database.GetDB() + out := make(map[string]int, len(emails)) + for _, batch := range chunkEmails(emails, ipScanChunk) { + var rows []struct { + Email string + LimitIp int + } + if err := db.Model(&model.ClientRecord{}). + Select("email, limit_ip"). + Where("email IN ?", batch). + Scan(&rows).Error; err != nil { + j.checkError(err) + continue + } + for _, r := range rows { + out[r.Email] = r.LimitIp + } + } + return out +} + +// loadInboundsByEmails resolves each email's owning inbound through the +// clients/client_inbounds relation in chunked queries. Like the old per-email +// First() it keeps the lowest inbound id when a client spans several inbounds. +func (j *CheckClientIpJob) loadInboundsByEmails(emails []string) map[string]*model.Inbound { + db := database.GetDB() + minInboundByEmail := make(map[string]int, len(emails)) + for _, batch := range chunkEmails(emails, ipScanChunk) { + var pairs []struct { + Email string + InboundId int + } + if err := db.Table("client_inbounds"). + Select("clients.email AS email, client_inbounds.inbound_id AS inbound_id"). + Joins("JOIN clients ON clients.id = client_inbounds.client_id"). + Where("clients.email IN ?", batch). + Scan(&pairs).Error; err != nil { + j.checkError(err) + return nil + } + for _, p := range pairs { + if cur, ok := minInboundByEmail[p.Email]; !ok || p.InboundId < cur { + minInboundByEmail[p.Email] = p.InboundId + } + } + } + if len(minInboundByEmail) == 0 { + return nil + } + + idSet := make(map[int]struct{}, len(minInboundByEmail)) + ids := make([]int, 0, len(minInboundByEmail)) + for _, id := range minInboundByEmail { + if _, seen := idSet[id]; !seen { + idSet[id] = struct{}{} + ids = append(ids, id) + } + } + sort.Ints(ids) + inboundsById := make(map[int]*model.Inbound, len(ids)) + for lo := 0; lo < len(ids); lo += ipScanChunk { + hi := min(lo+ipScanChunk, len(ids)) + var page []*model.Inbound + if err := db.Model(&model.Inbound{}).Where("id IN ?", ids[lo:hi]).Find(&page).Error; err != nil { + j.checkError(err) + return nil + } + for _, ib := range page { + inboundsById[ib.Id] = ib + } + } + + out := make(map[string]*model.Inbound, len(minInboundByEmail)) + for email, id := range minInboundByEmail { + if ib, ok := inboundsById[id]; ok { + out[email] = ib + } + } + return out +} + +func (j *CheckClientIpJob) loadClientIpRows(emails []string) map[string]*model.InboundClientIps { + db := database.GetDB() + out := make(map[string]*model.InboundClientIps, len(emails)) + for _, batch := range chunkEmails(emails, ipScanChunk) { + var rows []model.InboundClientIps + if err := db.Where("client_email IN ?", batch).Find(&rows).Error; err != nil { + j.checkError(err) + continue + } + for i := range rows { + out[rows[i].ClientEmail] = &rows[i] + } + } + return out +} + // processObserved runs collection + enforcement for one scan's observations // (email -> ip -> last-seen unix seconds). observedAreLive marks the // observations as live connections, which bypass the stale cutoff: a connection // that opened hours ago is still live even though its timestamp is old. The // online-stats API always reports live connections, so the job passes true. +// Lookups are batched up front and all inbound_client_ips writes share one +// transaction, so a scan costs a handful of queries and one fsync instead of +// several per observed email. func (j *CheckClientIpJob) processObserved(observed map[string]map[string]int64, enforce, observedAreLive bool) bool { shouldCleanLog := false now := time.Now().Unix() + + emails := make([]string, 0, len(observed)) + for email := range observed { + emails = append(emails, email) + } + sort.Strings(emails) + + limitByEmail := j.loadClientLimits(emails) + inboundByEmail := j.loadInboundsByEmails(emails) + ipRowByEmail := j.loadClientIpRows(emails) + // attribution accumulates this scan's local observations per email so they can // be recorded under this panel's own guid for cross-node IP attribution. attribution := make(map[string][]model.ClientIpEntry, len(observed)) - for email, ipTimestamps := range observed { + + type pendingDisconnect struct { + inbound *model.Inbound + email string + } + var disconnects []pendingDisconnect + + db := database.GetDB() + tx := db.Begin() + if tx.Error != nil { + j.checkError(tx.Error) + return false + } + committed := false + defer func() { + if !committed { + tx.Rollback() + } + }() + + for _, email := range emails { + ipTimestamps := observed[email] // The observations can still reference a client that was just renamed // or deleted; its email no longer matches any inbound. Skip it (and // drop any orphaned tracking row) instead of recreating a row and - // logging an ERROR every run (#4963). - inbound, err := j.getInboundByEmail(email) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - logger.Debugf("[LimitIP] skipping stale observed email %q (renamed or deleted)", email) - j.delInboundClientIps(email) - } else { - j.checkError(err) + // logging an ERROR every run (#4963). The batch map resolves through + // the clients relation; the per-email fallback keeps its settings LIKE + // net for clients not yet present there. + inbound, ok := inboundByEmail[email] + if !ok { + var err error + inbound, err = j.getInboundByEmail(email) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + logger.Debugf("[LimitIP] skipping stale observed email %q (renamed or deleted)", email) + j.delInboundClientIps(tx, email) + } else { + j.checkError(err) + } + continue } - continue } // Convert to IPWithTimestamp slice @@ -170,13 +325,44 @@ func (j *CheckClientIpJob) processObserved(observed map[string]map[string]int64, attribution[email] = attrEntries } - clientIpsRecord, err := j.getInboundClientIps(email) - if err != nil { - _ = j.addInboundClientIps(email, ipsWithTime) + clientIpsRecord, ok := ipRowByEmail[email] + if !ok { + jsonIps, err := json.Marshal(ipsWithTime) + if err != nil { + j.checkError(err) + continue + } + if err := tx.Save(&model.InboundClientIps{ClientEmail: email, Ips: string(jsonIps)}).Error; err != nil { + j.checkError(err) + } continue } - shouldCleanLog = j.updateInboundClientIps(clientIpsRecord, inbound, email, ipsWithTime, enforce, observedAreLive) || shouldCleanLog + cleaned, banned := j.updateInboundClientIps(tx, clientIpsRecord, inbound, email, limitByEmail[email], ipsWithTime, enforce, observedAreLive) + shouldCleanLog = cleaned || shouldCleanLog + if banned { + disconnects = append(disconnects, pendingDisconnect{inbound: inbound, email: email}) + } + } + + if err := tx.Commit().Error; err != nil { + j.checkError(err) + return shouldCleanLog + } + committed = true + + // Xray disconnects run after the commit so their network round-trips never + // extend the scan's write transaction (node syncs upsert the same table). + clientsCache := make(map[int][]model.Client) + for _, d := range disconnects { + clients, cached := clientsCache[d.inbound.Id] + if !cached { + settings := map[string][]model.Client{} + _ = json.Unmarshal([]byte(d.inbound.Settings), &settings) + clients = settings["clients"] + clientsCache[d.inbound.Id] = clients + } + j.disconnectClientTemporarily(d.inbound, d.email, clients) } j.recordLocalAttribution(attribution) @@ -275,81 +461,34 @@ func (j *CheckClientIpJob) checkError(e error) { } } -func (j *CheckClientIpJob) getInboundClientIps(clientEmail string) (*model.InboundClientIps, error) { - db := database.GetDB() - InboundClientIps := &model.InboundClientIps{} - err := db.Model(model.InboundClientIps{}).Where("client_email = ?", clientEmail).First(InboundClientIps).Error - if err != nil { - return nil, err - } - return InboundClientIps, nil -} - -func (j *CheckClientIpJob) addInboundClientIps(clientEmail string, ipsWithTime []IPWithTimestamp) error { - inboundClientIps := &model.InboundClientIps{} - jsonIps, err := json.Marshal(ipsWithTime) - j.checkError(err) - - inboundClientIps.ClientEmail = clientEmail - inboundClientIps.Ips = string(jsonIps) - - db := database.GetDB() - tx := db.Begin() - - defer func() { - if err == nil { - tx.Commit() - } else { - tx.Rollback() - } - }() - - err = tx.Save(inboundClientIps).Error - if err != nil { - return err - } - return nil -} - // delInboundClientIps drops the inbound_client_ips tracking row for an email // that no longer maps to any inbound (a renamed or deleted client), so stale // access-log entries don't keep a ghost row alive (#4963). -func (j *CheckClientIpJob) delInboundClientIps(clientEmail string) { - db := database.GetDB() - if err := db.Where("client_email = ?", clientEmail).Delete(&model.InboundClientIps{}).Error; err != nil { +func (j *CheckClientIpJob) delInboundClientIps(tx *gorm.DB, clientEmail string) { + if err := tx.Where("client_email = ?", clientEmail).Delete(&model.InboundClientIps{}).Error; err != nil { j.checkError(err) } } -func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.InboundClientIps, inbound *model.Inbound, clientEmail string, newIpsWithTime []IPWithTimestamp, enforce, observedAreLive bool) bool { +// updateInboundClientIps merges one email's observed IPs into its tracking row +// and applies the IP limit. limitIp comes from the caller (the clients table); +// writes go through the caller's transaction. banned=true asks the caller to +// disconnect the client after the transaction commits. +func (j *CheckClientIpJob) updateInboundClientIps(tx *gorm.DB, inboundClientIps *model.InboundClientIps, inbound *model.Inbound, clientEmail string, limitIp int, newIpsWithTime []IPWithTimestamp, enforce, observedAreLive bool) (shouldCleanLog, banned bool) { if inbound.Settings == "" { logger.Debug("wrong data:", inbound) - return false + return false, false } - settings := map[string][]model.Client{} - _ = json.Unmarshal([]byte(inbound.Settings), &settings) - clients := settings["clients"] - - // Find the client's IP limit - var limitIp int - var clientFound bool - for _, client := range clients { - if client.Email == clientEmail { - limitIp = client.LimitIP - clientFound = true - break - } - } - - if !enforce || !clientFound || limitIp <= 0 || !inbound.Enable { - // Nothing to enforce (collection-only run, no limit, client missing, or - // inbound disabled): record the observed IPs for the panel and return. + if !enforce || limitIp <= 0 || !inbound.Enable { + // Nothing to enforce (collection-only run, no limit on the clients row, + // or inbound disabled): record the observed IPs for the panel and return. jsonIps, _ := json.Marshal(newIpsWithTime) inboundClientIps.Ips = string(jsonIps) - db := database.GetDB() - db.Save(inboundClientIps) - return false + if err := tx.Save(inboundClientIps).Error; err != nil { + logger.Error("failed to save inboundClientIps:", err) + } + return false, false } // Parse old IPs from database @@ -368,18 +507,18 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun } liveIps, historicalIps := partitionLiveIps(ipMap, observedThisScan) - shouldCleanLog := false j.disAllowedIps = []string{} // historical db-only ips are excluded from this count on purpose. keptLive, bannedLive := selectIpsToBan(liveIps, limitIp) if len(bannedLive) > 0 { shouldCleanLog = true + banned = true logIpFile, err := os.OpenFile(xray.GetIPLimitLogPath(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) if err != nil { logger.Errorf("failed to open IP limit log file: %s", err) - return false + return false, false } defer logIpFile.Close() ipLogger := log.New(logIpFile, "", log.LstdFlags) @@ -392,9 +531,6 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun j.disAllowedIps = append(j.disAllowedIps, ipTime.IP) ipLogger.Printf("[LIMIT_IP] Email = %s || Disconnecting OLD IP = %s || Timestamp = %d", clientEmail, ipTime.IP, ipTime.Timestamp) } - - // force xray to drop existing connections from banned ips - j.disconnectClientTemporarily(inbound, clientEmail, clients) } // keep kept-live + historical in the blob so the panel keeps showing @@ -406,18 +542,16 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun jsonIps, _ := json.Marshal(dbIps) inboundClientIps.Ips = string(jsonIps) - db := database.GetDB() - err := db.Save(inboundClientIps).Error - if err != nil { + if err := tx.Save(inboundClientIps).Error; err != nil { logger.Error("failed to save inboundClientIps:", err) - return false + return false, banned } if len(j.disAllowedIps) > 0 { logger.Infof("[LIMIT_IP] Client %s: Kept %d live IPs, queued %d old IPs for fail2ban", clientEmail, len(keptLive), len(j.disAllowedIps)) } - return shouldCleanLog + return shouldCleanLog, banned } // disconnectClientTemporarily removes and re-adds a client to force disconnect banned connections diff --git a/internal/web/job/check_client_ip_job_integration_test.go b/internal/web/job/check_client_ip_job_integration_test.go index e28b9cc04..c0265ab87 100644 --- a/internal/web/job/check_client_ip_job_integration_test.go +++ b/internal/web/job/check_client_ip_job_integration_test.go @@ -95,7 +95,7 @@ func seedInboundOnlyWithClient(t *testing.T, tag, email string, limitIp int) *mo func seedLinkedInboundWithClient(t *testing.T, tag, email string, limitIp int) *model.Inbound { t.Helper() inbound := seedInboundOnlyWithClient(t, tag, email, limitIp) - client := &model.ClientRecord{Email: email} + client := &model.ClientRecord{Email: email, LimitIP: limitIp} if err := database.GetDB().Create(client).Error; err != nil { t.Fatalf("seed client record: %v", err) } @@ -206,11 +206,14 @@ func TestUpdateInboundClientIps_LiveIpNotBannedByStillFreshHistoricals(t *testin if err != nil { t.Fatalf("getInboundByEmail: %v", err) } - shouldCleanLog := j.updateInboundClientIps(row, inbound, email, live, true, false) + shouldCleanLog, banned := j.updateInboundClientIps(database.GetDB(), row, inbound, email, 3, live, true, false) if shouldCleanLog { t.Fatalf("shouldCleanLog must be false, nothing should have been banned with 1 live ip under limit 3") } + if banned { + t.Fatalf("banned must be false with 1 live ip under limit 3") + } if len(j.disAllowedIps) != 0 { t.Fatalf("disAllowedIps must be empty, got %v", j.disAllowedIps) } @@ -259,11 +262,14 @@ func TestUpdateInboundClientIps_ExcessLiveIpIsStillBanned(t *testing.T) { if err != nil { t.Fatalf("getInboundByEmail: %v", err) } - shouldCleanLog := j.updateInboundClientIps(row, inbound, email, live, true, false) + shouldCleanLog, banned := j.updateInboundClientIps(database.GetDB(), row, inbound, email, 1, live, true, false) if !shouldCleanLog { t.Fatalf("shouldCleanLog must be true when the live set exceeds the limit") } + if !banned { + t.Fatalf("banned must be true when the live set exceeds the limit") + } if len(j.disAllowedIps) != 1 || j.disAllowedIps[0] != "10.1.0.1" { t.Fatalf("expected 10.1.0.1 to be banned; disAllowedIps = %v", j.disAllowedIps) }