diff --git a/internal/web/job/check_client_ip_job.go b/internal/web/job/check_client_ip_job.go index 5faa76e35..e016f64f7 100644 --- a/internal/web/job/check_client_ip_job.go +++ b/internal/web/job/check_client_ip_job.go @@ -124,32 +124,187 @@ func (j *CheckClientIpJob) hasLimitIp() bool { return err == nil && probe > 0 } +const ipScanChunk = 400 + +func chunkEmails(s []string, size int) [][]string { + if len(s) == 0 { + return nil + } + chunks := make([][]string, 0, (len(s)+size-1)/size) + for size < len(s) { + s, chunks = s[size:], append(chunks, s[:size]) + } + return append(chunks, s) +} + +// loadClientLimits maps each observed email to its clients.limit_ip in a few +// chunked queries, replacing the per-email settings-JSON parse that previously +// resolved the limit. +func (j *CheckClientIpJob) loadClientLimits(emails []string) map[string]int { + db := database.GetDB() + out := make(map[string]int, len(emails)) + for _, batch := range chunkEmails(emails, ipScanChunk) { + var rows []struct { + Email string + LimitIp int + } + if err := db.Model(&model.ClientRecord{}). + Select("email, limit_ip"). + Where("email IN ?", batch). + Scan(&rows).Error; err != nil { + j.checkError(err) + continue + } + for _, r := range rows { + out[r.Email] = r.LimitIp + } + } + return out +} + +// loadInboundsByEmails resolves each email's owning inbound through the +// clients/client_inbounds relation in chunked queries. Like the old per-email +// First() it keeps the lowest inbound id when a client spans several inbounds. +func (j *CheckClientIpJob) loadInboundsByEmails(emails []string) map[string]*model.Inbound { + db := database.GetDB() + minInboundByEmail := make(map[string]int, len(emails)) + for _, batch := range chunkEmails(emails, ipScanChunk) { + var pairs []struct { + Email string + InboundId int + } + if err := db.Table("client_inbounds"). + Select("clients.email AS email, client_inbounds.inbound_id AS inbound_id"). + Joins("JOIN clients ON clients.id = client_inbounds.client_id"). + Where("clients.email IN ?", batch). + Scan(&pairs).Error; err != nil { + j.checkError(err) + return nil + } + for _, p := range pairs { + if cur, ok := minInboundByEmail[p.Email]; !ok || p.InboundId < cur { + minInboundByEmail[p.Email] = p.InboundId + } + } + } + if len(minInboundByEmail) == 0 { + return nil + } + + idSet := make(map[int]struct{}, len(minInboundByEmail)) + ids := make([]int, 0, len(minInboundByEmail)) + for _, id := range minInboundByEmail { + if _, seen := idSet[id]; !seen { + idSet[id] = struct{}{} + ids = append(ids, id) + } + } + sort.Ints(ids) + inboundsById := make(map[int]*model.Inbound, len(ids)) + for lo := 0; lo < len(ids); lo += ipScanChunk { + hi := min(lo+ipScanChunk, len(ids)) + var page []*model.Inbound + if err := db.Model(&model.Inbound{}).Where("id IN ?", ids[lo:hi]).Find(&page).Error; err != nil { + j.checkError(err) + return nil + } + for _, ib := range page { + inboundsById[ib.Id] = ib + } + } + + out := make(map[string]*model.Inbound, len(minInboundByEmail)) + for email, id := range minInboundByEmail { + if ib, ok := inboundsById[id]; ok { + out[email] = ib + } + } + return out +} + +func (j *CheckClientIpJob) loadClientIpRows(emails []string) map[string]*model.InboundClientIps { + db := database.GetDB() + out := make(map[string]*model.InboundClientIps, len(emails)) + for _, batch := range chunkEmails(emails, ipScanChunk) { + var rows []model.InboundClientIps + if err := db.Where("client_email IN ?", batch).Find(&rows).Error; err != nil { + j.checkError(err) + continue + } + for i := range rows { + out[rows[i].ClientEmail] = &rows[i] + } + } + return out +} + // processObserved runs collection + enforcement for one scan's observations // (email -> ip -> last-seen unix seconds). observedAreLive marks the // observations as live connections, which bypass the stale cutoff: a connection // that opened hours ago is still live even though its timestamp is old. The // online-stats API always reports live connections, so the job passes true. +// Lookups are batched up front and all inbound_client_ips writes share one +// transaction, so a scan costs a handful of queries and one fsync instead of +// several per observed email. func (j *CheckClientIpJob) processObserved(observed map[string]map[string]int64, enforce, observedAreLive bool) bool { shouldCleanLog := false now := time.Now().Unix() + + emails := make([]string, 0, len(observed)) + for email := range observed { + emails = append(emails, email) + } + sort.Strings(emails) + + limitByEmail := j.loadClientLimits(emails) + inboundByEmail := j.loadInboundsByEmails(emails) + ipRowByEmail := j.loadClientIpRows(emails) + // attribution accumulates this scan's local observations per email so they can // be recorded under this panel's own guid for cross-node IP attribution. attribution := make(map[string][]model.ClientIpEntry, len(observed)) - for email, ipTimestamps := range observed { + + type pendingDisconnect struct { + inbound *model.Inbound + email string + } + var disconnects []pendingDisconnect + + db := database.GetDB() + tx := db.Begin() + if tx.Error != nil { + j.checkError(tx.Error) + return false + } + committed := false + defer func() { + if !committed { + tx.Rollback() + } + }() + + for _, email := range emails { + ipTimestamps := observed[email] // The observations can still reference a client that was just renamed // or deleted; its email no longer matches any inbound. Skip it (and // drop any orphaned tracking row) instead of recreating a row and - // logging an ERROR every run (#4963). - inbound, err := j.getInboundByEmail(email) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - logger.Debugf("[LimitIP] skipping stale observed email %q (renamed or deleted)", email) - j.delInboundClientIps(email) - } else { - j.checkError(err) + // logging an ERROR every run (#4963). The batch map resolves through + // the clients relation; the per-email fallback keeps its settings LIKE + // net for clients not yet present there. + inbound, ok := inboundByEmail[email] + if !ok { + var err error + inbound, err = j.getInboundByEmail(email) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + logger.Debugf("[LimitIP] skipping stale observed email %q (renamed or deleted)", email) + j.delInboundClientIps(tx, email) + } else { + j.checkError(err) + } + continue } - continue } // Convert to IPWithTimestamp slice @@ -170,13 +325,44 @@ func (j *CheckClientIpJob) processObserved(observed map[string]map[string]int64, attribution[email] = attrEntries } - clientIpsRecord, err := j.getInboundClientIps(email) - if err != nil { - _ = j.addInboundClientIps(email, ipsWithTime) + clientIpsRecord, ok := ipRowByEmail[email] + if !ok { + jsonIps, err := json.Marshal(ipsWithTime) + if err != nil { + j.checkError(err) + continue + } + if err := tx.Save(&model.InboundClientIps{ClientEmail: email, Ips: string(jsonIps)}).Error; err != nil { + j.checkError(err) + } continue } - shouldCleanLog = j.updateInboundClientIps(clientIpsRecord, inbound, email, ipsWithTime, enforce, observedAreLive) || shouldCleanLog + cleaned, banned := j.updateInboundClientIps(tx, clientIpsRecord, inbound, email, limitByEmail[email], ipsWithTime, enforce, observedAreLive) + shouldCleanLog = cleaned || shouldCleanLog + if banned { + disconnects = append(disconnects, pendingDisconnect{inbound: inbound, email: email}) + } + } + + if err := tx.Commit().Error; err != nil { + j.checkError(err) + return shouldCleanLog + } + committed = true + + // Xray disconnects run after the commit so their network round-trips never + // extend the scan's write transaction (node syncs upsert the same table). + clientsCache := make(map[int][]model.Client) + for _, d := range disconnects { + clients, cached := clientsCache[d.inbound.Id] + if !cached { + settings := map[string][]model.Client{} + _ = json.Unmarshal([]byte(d.inbound.Settings), &settings) + clients = settings["clients"] + clientsCache[d.inbound.Id] = clients + } + j.disconnectClientTemporarily(d.inbound, d.email, clients) } j.recordLocalAttribution(attribution) @@ -275,81 +461,34 @@ func (j *CheckClientIpJob) checkError(e error) { } } -func (j *CheckClientIpJob) getInboundClientIps(clientEmail string) (*model.InboundClientIps, error) { - db := database.GetDB() - InboundClientIps := &model.InboundClientIps{} - err := db.Model(model.InboundClientIps{}).Where("client_email = ?", clientEmail).First(InboundClientIps).Error - if err != nil { - return nil, err - } - return InboundClientIps, nil -} - -func (j *CheckClientIpJob) addInboundClientIps(clientEmail string, ipsWithTime []IPWithTimestamp) error { - inboundClientIps := &model.InboundClientIps{} - jsonIps, err := json.Marshal(ipsWithTime) - j.checkError(err) - - inboundClientIps.ClientEmail = clientEmail - inboundClientIps.Ips = string(jsonIps) - - db := database.GetDB() - tx := db.Begin() - - defer func() { - if err == nil { - tx.Commit() - } else { - tx.Rollback() - } - }() - - err = tx.Save(inboundClientIps).Error - if err != nil { - return err - } - return nil -} - // delInboundClientIps drops the inbound_client_ips tracking row for an email // that no longer maps to any inbound (a renamed or deleted client), so stale // access-log entries don't keep a ghost row alive (#4963). -func (j *CheckClientIpJob) delInboundClientIps(clientEmail string) { - db := database.GetDB() - if err := db.Where("client_email = ?", clientEmail).Delete(&model.InboundClientIps{}).Error; err != nil { +func (j *CheckClientIpJob) delInboundClientIps(tx *gorm.DB, clientEmail string) { + if err := tx.Where("client_email = ?", clientEmail).Delete(&model.InboundClientIps{}).Error; err != nil { j.checkError(err) } } -func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.InboundClientIps, inbound *model.Inbound, clientEmail string, newIpsWithTime []IPWithTimestamp, enforce, observedAreLive bool) bool { +// updateInboundClientIps merges one email's observed IPs into its tracking row +// and applies the IP limit. limitIp comes from the caller (the clients table); +// writes go through the caller's transaction. banned=true asks the caller to +// disconnect the client after the transaction commits. +func (j *CheckClientIpJob) updateInboundClientIps(tx *gorm.DB, inboundClientIps *model.InboundClientIps, inbound *model.Inbound, clientEmail string, limitIp int, newIpsWithTime []IPWithTimestamp, enforce, observedAreLive bool) (shouldCleanLog, banned bool) { if inbound.Settings == "" { logger.Debug("wrong data:", inbound) - return false + return false, false } - settings := map[string][]model.Client{} - _ = json.Unmarshal([]byte(inbound.Settings), &settings) - clients := settings["clients"] - - // Find the client's IP limit - var limitIp int - var clientFound bool - for _, client := range clients { - if client.Email == clientEmail { - limitIp = client.LimitIP - clientFound = true - break - } - } - - if !enforce || !clientFound || limitIp <= 0 || !inbound.Enable { - // Nothing to enforce (collection-only run, no limit, client missing, or - // inbound disabled): record the observed IPs for the panel and return. + if !enforce || limitIp <= 0 || !inbound.Enable { + // Nothing to enforce (collection-only run, no limit on the clients row, + // or inbound disabled): record the observed IPs for the panel and return. jsonIps, _ := json.Marshal(newIpsWithTime) inboundClientIps.Ips = string(jsonIps) - db := database.GetDB() - db.Save(inboundClientIps) - return false + if err := tx.Save(inboundClientIps).Error; err != nil { + logger.Error("failed to save inboundClientIps:", err) + } + return false, false } // Parse old IPs from database @@ -368,18 +507,18 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun } liveIps, historicalIps := partitionLiveIps(ipMap, observedThisScan) - shouldCleanLog := false j.disAllowedIps = []string{} // historical db-only ips are excluded from this count on purpose. keptLive, bannedLive := selectIpsToBan(liveIps, limitIp) if len(bannedLive) > 0 { shouldCleanLog = true + banned = true logIpFile, err := os.OpenFile(xray.GetIPLimitLogPath(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644) if err != nil { logger.Errorf("failed to open IP limit log file: %s", err) - return false + return false, false } defer logIpFile.Close() ipLogger := log.New(logIpFile, "", log.LstdFlags) @@ -392,9 +531,6 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun j.disAllowedIps = append(j.disAllowedIps, ipTime.IP) ipLogger.Printf("[LIMIT_IP] Email = %s || Disconnecting OLD IP = %s || Timestamp = %d", clientEmail, ipTime.IP, ipTime.Timestamp) } - - // force xray to drop existing connections from banned ips - j.disconnectClientTemporarily(inbound, clientEmail, clients) } // keep kept-live + historical in the blob so the panel keeps showing @@ -406,18 +542,16 @@ func (j *CheckClientIpJob) updateInboundClientIps(inboundClientIps *model.Inboun jsonIps, _ := json.Marshal(dbIps) inboundClientIps.Ips = string(jsonIps) - db := database.GetDB() - err := db.Save(inboundClientIps).Error - if err != nil { + if err := tx.Save(inboundClientIps).Error; err != nil { logger.Error("failed to save inboundClientIps:", err) - return false + return false, banned } if len(j.disAllowedIps) > 0 { logger.Infof("[LIMIT_IP] Client %s: Kept %d live IPs, queued %d old IPs for fail2ban", clientEmail, len(keptLive), len(j.disAllowedIps)) } - return shouldCleanLog + return shouldCleanLog, banned } // disconnectClientTemporarily removes and re-adds a client to force disconnect banned connections diff --git a/internal/web/job/check_client_ip_job_integration_test.go b/internal/web/job/check_client_ip_job_integration_test.go index e28b9cc04..c0265ab87 100644 --- a/internal/web/job/check_client_ip_job_integration_test.go +++ b/internal/web/job/check_client_ip_job_integration_test.go @@ -95,7 +95,7 @@ func seedInboundOnlyWithClient(t *testing.T, tag, email string, limitIp int) *mo func seedLinkedInboundWithClient(t *testing.T, tag, email string, limitIp int) *model.Inbound { t.Helper() inbound := seedInboundOnlyWithClient(t, tag, email, limitIp) - client := &model.ClientRecord{Email: email} + client := &model.ClientRecord{Email: email, LimitIP: limitIp} if err := database.GetDB().Create(client).Error; err != nil { t.Fatalf("seed client record: %v", err) } @@ -206,11 +206,14 @@ func TestUpdateInboundClientIps_LiveIpNotBannedByStillFreshHistoricals(t *testin if err != nil { t.Fatalf("getInboundByEmail: %v", err) } - shouldCleanLog := j.updateInboundClientIps(row, inbound, email, live, true, false) + shouldCleanLog, banned := j.updateInboundClientIps(database.GetDB(), row, inbound, email, 3, live, true, false) if shouldCleanLog { t.Fatalf("shouldCleanLog must be false, nothing should have been banned with 1 live ip under limit 3") } + if banned { + t.Fatalf("banned must be false with 1 live ip under limit 3") + } if len(j.disAllowedIps) != 0 { t.Fatalf("disAllowedIps must be empty, got %v", j.disAllowedIps) } @@ -259,11 +262,14 @@ func TestUpdateInboundClientIps_ExcessLiveIpIsStillBanned(t *testing.T) { if err != nil { t.Fatalf("getInboundByEmail: %v", err) } - shouldCleanLog := j.updateInboundClientIps(row, inbound, email, live, true, false) + shouldCleanLog, banned := j.updateInboundClientIps(database.GetDB(), row, inbound, email, 1, live, true, false) if !shouldCleanLog { t.Fatalf("shouldCleanLog must be true when the live set exceeds the limit") } + if !banned { + t.Fatalf("banned must be true when the live set exceeds the limit") + } if len(j.disAllowedIps) != 1 || j.disAllowedIps[0] != "10.1.0.1" { t.Fatalf("expected 10.1.0.1 to be banned; disAllowedIps = %v", j.disAllowedIps) }