diff --git a/e2e_test.go b/e2e_test.go index 9a63a0a..6e8c7a3 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -96,7 +96,7 @@ func TestE2E_FetchMetadataThroughDNS(t *testing.T) { t.Fatalf("create fetcher: %v", err) } - meta, err := fetcher.FetchMetadata() + meta, err := fetcher.FetchMetadata(context.Background()) if err != nil { t.Fatalf("fetch metadata: %v", err) } @@ -136,7 +136,7 @@ func TestE2E_FetchChannelMessages(t *testing.T) { t.Fatalf("create fetcher: %v", err) } - meta, err := fetcher.FetchMetadata() + meta, err := fetcher.FetchMetadata(context.Background()) if err != nil { t.Fatalf("fetch metadata: %v", err) } @@ -146,7 +146,7 @@ func TestE2E_FetchChannelMessages(t *testing.T) { t.Fatal("expected blocks > 0") } - fetchedMsgs, err := fetcher.FetchChannel(1, blockCount) + fetchedMsgs, err := fetcher.FetchChannel(context.Background(), 1, blockCount) if err != nil { t.Fatalf("fetch channel: %v", err) } @@ -180,14 +180,14 @@ func TestE2E_FetchWithDoubleLabel(t *testing.T) { if err != nil { t.Fatalf("create fetcher: %v", err) } - fetcher.SetQueryMode(protocol.QueryDoubleLabel) + fetcher.SetQueryMode(protocol.QueryMultiLabel) - meta, err := fetcher.FetchMetadata() + meta, err := fetcher.FetchMetadata(context.Background()) if err != nil { t.Fatalf("fetch metadata: %v", err) } - fetchedMsgs, err := fetcher.FetchChannel(1, int(meta.Channels[0].Blocks)) + fetchedMsgs, err := fetcher.FetchChannel(context.Background(), 1, int(meta.Channels[0].Blocks)) if err != nil { t.Fatalf("fetch channel: %v", err) } @@ -216,7 +216,7 @@ func TestE2E_WrongPassphrase(t *testing.T) { t.Fatalf("create fetcher: %v", err) } - _, err = fetcher.FetchMetadata() + _, err = fetcher.FetchMetadata(context.Background()) if err == nil { t.Fatal("expected error with wrong passphrase, got nil") } @@ -243,12 +243,12 @@ func TestE2E_LargeMessages(t *testing.T) { t.Fatalf("create fetcher: %v", err) } - meta, err := fetcher.FetchMetadata() + meta, err := fetcher.FetchMetadata(context.Background()) if err != nil { t.Fatalf("fetch metadata: %v", err) } - fetchedMsgs, err := fetcher.FetchChannel(1, int(meta.Channels[0].Blocks)) + fetchedMsgs, err := fetcher.FetchChannel(context.Background(), 1, int(meta.Channels[0].Blocks)) if err != nil { t.Fatalf("fetch channel: %v", err) } @@ -586,8 +586,15 @@ func TestE2E_FullRoundTrip(t *testing.T) { t.Fatalf("config POST status=%d", resp.StatusCode) } - // Wait for auto-refresh to fetch data - time.Sleep(3 * time.Second) + // Refresh channels via selected-channel API semantics. + respRefresh1, err := http.Post(base+"/api/refresh?channel=1", "application/json", nil) + if err != nil { + t.Fatalf("POST /api/refresh?channel=1: %v", err) + } + respRefresh1.Body.Close() + // Give channel 1 refresh goroutine time to complete before refreshing channel 2, + // because starting a new refresh cancels the previous in-flight refresh. + time.Sleep(700 * time.Millisecond) // Channels should be populated resp2, err := http.Get(base + "/api/channels") @@ -621,6 +628,13 @@ func TestE2E_FullRoundTrip(t *testing.T) { t.Errorf("msg[0].Text = %q, want %q", msgList[0].Text, "General message 1") } + respRefresh2, err := http.Post(base+"/api/refresh?channel=2", "application/json", nil) + if err != nil { + t.Fatalf("POST /api/refresh?channel=2: %v", err) + } + respRefresh2.Body.Close() + time.Sleep(700 * time.Millisecond) + // Messages for channel 2 resp4, err := http.Get(base + "/api/messages/2") if err != nil { diff --git a/internal/client/fetcher.go b/internal/client/fetcher.go index 4ff6090..34f34ce 100644 --- a/internal/client/fetcher.go +++ b/internal/client/fetcher.go @@ -1,6 +1,7 @@ package client import ( + "context" "fmt" "math/rand" "strings" @@ -12,26 +13,36 @@ import ( "github.com/sartoopjj/thefeed/internal/protocol" ) -// LogFunc is a callback for logging DNS queries (for debug/TUI). +// LogFunc is a callback for log messages. type LogFunc func(msg string) +// noiseDomains are popular domains used to blend feed queries into normal-looking DNS traffic. +var noiseDomains = []string{ + "www.google.com", "www.cloudflare.com", "one.one.one.one", + "www.youtube.com", "www.instagram.com", "www.amazon.com", + "www.microsoft.com", "www.apple.com", "www.github.com", + "www.wikipedia.org", "www.reddit.com", "www.twitter.com", +} + // Fetcher fetches feed blocks over DNS. type Fetcher struct { domain string queryKey [protocol.KeySize]byte responseKey [protocol.KeySize]byte queryMode protocol.QueryEncoding + timeout time.Duration - mu sync.RWMutex - resolvers []string - timeout time.Duration + // Resolver pools — allResolvers is what the user configured; + // activeResolvers is kept up-to-date by ResolverChecker (only healthy ones). + mu sync.RWMutex + allResolvers []string + activeResolvers []string - // Rate limiting - rateMu sync.Mutex - queryDelay time.Duration - lastQuery time.Time + // Rate limiting via token bucket; nil means unlimited. + rateQPS float64 + rateCh chan struct{} - // Debug logging + debug bool logFunc LogFunc } @@ -42,23 +53,28 @@ func NewFetcher(domain, passphrase string, resolvers []string) (*Fetcher, error) return nil, fmt.Errorf("derive keys: %w", err) } + r := make([]string, len(resolvers)) + copy(r, resolvers) + return &Fetcher{ - domain: strings.TrimSuffix(domain, "."), - queryKey: qk, - responseKey: rk, - queryMode: protocol.QuerySingleLabel, - resolvers: resolvers, - timeout: 5 * time.Second, + domain: strings.TrimSuffix(domain, "."), + queryKey: qk, + responseKey: rk, + queryMode: protocol.QuerySingleLabel, + allResolvers: r, + activeResolvers: r, + timeout: 10 * time.Second, }, nil } -// SetRateLimit sets the maximum queries per second (0 = unlimited). +// SetRateLimit sets the maximum queries per second (0 = unlimited). Must be called before Start. func (f *Fetcher) SetRateLimit(qps float64) { - if qps <= 0 { - f.queryDelay = 0 - return - } - f.queryDelay = time.Duration(float64(time.Second) / qps) + f.rateQPS = qps +} + +// SetTimeout sets the per-query DNS timeout. +func (f *Fetcher) SetTimeout(d time.Duration) { + f.timeout = d } // SetLogFunc sets the debug log callback. @@ -66,83 +82,225 @@ func (f *Fetcher) SetLogFunc(fn LogFunc) { f.logFunc = fn } +// SetDebug enables or disables debug logging of generated query names. +func (f *Fetcher) SetDebug(debug bool) { + f.debug = debug +} + // SetQueryMode sets the DNS query encoding mode. func (f *Fetcher) SetQueryMode(mode protocol.QueryEncoding) { f.queryMode = mode } +// SetActiveResolvers updates the healthy resolver pool. Called by ResolverChecker. +// If the new list is empty, the current pool is unchanged (to avoid blackout during a bad check). +func (f *Fetcher) SetActiveResolvers(resolvers []string) { + f.mu.Lock() + defer f.mu.Unlock() + if len(resolvers) > 0 { + f.activeResolvers = make([]string, len(resolvers)) + copy(f.activeResolvers, resolvers) + } +} + +// SetResolvers replaces the full resolver list and resets the active pool. +func (f *Fetcher) SetResolvers(resolvers []string) { + f.mu.Lock() + defer f.mu.Unlock() + f.allResolvers = make([]string, len(resolvers)) + copy(f.allResolvers, resolvers) + f.activeResolvers = make([]string, len(resolvers)) + copy(f.activeResolvers, resolvers) +} + +// AllResolvers returns all user-configured resolvers. +func (f *Fetcher) AllResolvers() []string { + f.mu.RLock() + defer f.mu.RUnlock() + result := make([]string, len(f.allResolvers)) + copy(result, f.allResolvers) + return result +} + +// Resolvers returns the currently active (healthy) resolver list. +func (f *Fetcher) Resolvers() []string { + f.mu.RLock() + defer f.mu.RUnlock() + result := make([]string, len(f.activeResolvers)) + copy(result, f.activeResolvers) + return result +} + +// Start launches background goroutines (rate limiter and noise generator). +// ctx controls their lifetime — cancel it to cleanly stop them. +// Call once per fetcher configuration; creating a new fetcher replaces the old one. +func (f *Fetcher) Start(ctx context.Context) { + if f.rateQPS > 0 { + f.rateCh = make(chan struct{}, 1) + go f.runRateLimiter(ctx) + go f.runNoise(ctx) + } +} + +// runRateLimiter issues one token into rateCh every 1/QPS seconds. +// The channel capacity is 1, so tokens do not accumulate (no burst). +func (f *Fetcher) runRateLimiter(ctx context.Context) { + interval := time.Duration(float64(time.Second) / f.rateQPS) + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + select { + case f.rateCh <- struct{}{}: + default: // bucket full; discard extra token to prevent burst + } + } + } +} + +// runNoise sends decoy A-record queries to popular domains at a randomised +// rate matching the configured QPS, to make feed traffic look like normal DNS usage. +func (f *Fetcher) runNoise(ctx context.Context) { + interval := time.Duration(float64(time.Second) / f.rateQPS) + for { + // Random delay: 1–3× the query interval. + jitter := time.Duration(rand.Int63n(int64(2*interval) + 1)) + select { + case <-ctx.Done(): + return + case <-time.After(interval + jitter): + } + + resolvers := f.Resolvers() + if len(resolvers) == 0 { + continue + } + resolver := resolvers[rand.Intn(len(resolvers))] + if !strings.Contains(resolver, ":") { + resolver += ":53" + } + target := noiseDomains[rand.Intn(len(noiseDomains))] + + go func(r, d string) { + c := &dns.Client{Timeout: f.timeout} + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(d), dns.TypeA) + m.RecursionDesired = true + c.Exchange(m, r) //nolint:errcheck — fire-and-forget noise query + }(resolver, target) + } +} + func (f *Fetcher) log(format string, args ...any) { if f.logFunc != nil { f.logFunc(fmt.Sprintf(format, args...)) } } -func (f *Fetcher) rateWait() { - if f.queryDelay <= 0 { +// logProgress logs a progress bar: "prefix [====> ] 45%" +func (f *Fetcher) logProgress(prefix string, current, total float64) { + if f.logFunc == nil || total <= 0 { return } - f.rateMu.Lock() - defer f.rateMu.Unlock() - elapsed := time.Since(f.lastQuery) - if elapsed < f.queryDelay { - time.Sleep(f.queryDelay - elapsed) + + percent := int((current / total) * 100) + barLen := 20 + filled := int((current / total) * float64(barLen)) + empty := barLen - filled + + bar := "[" + for i := 0; i < filled; i++ { + bar += "=" } - f.lastQuery = time.Now() + if filled < barLen { + bar += ">" + } + for i := 0; i < empty-1; i++ { + bar += " " + } + bar += "]" + + f.logFunc(fmt.Sprintf("%s %s %d%%", prefix, bar, percent)) } -// SetResolvers replaces the resolver list. -func (f *Fetcher) SetResolvers(resolvers []string) { - f.mu.Lock() - defer f.mu.Unlock() - f.resolvers = resolvers +// rateWait blocks until a rate-limit token is available or ctx is cancelled. +// Returns nil when a token was acquired, ctx.Err() when cancelled. +func (f *Fetcher) rateWait(ctx context.Context) error { + if f.rateCh == nil { + // Unlimited: just propagate any existing cancellation. + return ctx.Err() + } + select { + case <-f.rateCh: + return nil + case <-ctx.Done(): + return ctx.Err() + } } -// Resolvers returns the current resolver list. -func (f *Fetcher) Resolvers() []string { - f.mu.RLock() - defer f.mu.RUnlock() - result := make([]string, len(f.resolvers)) - copy(result, f.resolvers) - return result -} - -// FetchBlock fetches a single block from a channel. -func (f *Fetcher) FetchBlock(channel, block uint16) ([]byte, error) { - f.rateWait() - - qname, err := protocol.EncodeQuery(f.queryKey, channel, block, f.domain, f.queryMode) - if err != nil { - return nil, fmt.Errorf("encode query: %w", err) - } - - f.log("Q ch=%d blk=%d → %s", channel, block, qname) - - resolvers := f.Resolvers() - if len(resolvers) == 0 { - return nil, fmt.Errorf("no resolvers configured") - } - - // Shuffle resolvers to distribute load - shuffled := make([]string, len(resolvers)) - copy(shuffled, resolvers) - rand.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] }) - +// FetchBlock fetches a single encrypted block from the given channel. +// It enqueues through the rate limiter and respects ctx cancellation. +// On transient failure it retries up to 2 additional times with a short back-off. +func (f *Fetcher) FetchBlock(ctx context.Context, channel, block uint16) ([]byte, error) { + const maxAttempts = 3 var lastErr error - for _, resolver := range shuffled { - data, err := f.queryResolver(resolver, qname) - if err != nil { - lastErr = err - continue + for attempt := 0; attempt < maxAttempts; attempt++ { + if attempt > 0 { + // Brief back-off before retry; bail immediately if ctx is done. + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(attempt) * 500 * time.Millisecond): + } } - return data, nil - } - return nil, fmt.Errorf("all resolvers failed, last error: %w", lastErr) + if err := f.rateWait(ctx); err != nil { + return nil, err + } + + qname, err := protocol.EncodeQuery(f.queryKey, channel, block, f.domain, f.queryMode) + if err != nil { + return nil, fmt.Errorf("encode query: %w", err) + } + if f.debug { + f.log("[debug] query ch=%d blk=%d attempt=%d qname=%s", channel, block, attempt+1, qname) + } + + resolvers := f.Resolvers() + if len(resolvers) == 0 { + return nil, fmt.Errorf("no active resolvers") + } + + // Shuffle to spread load across resolvers. + shuffled := make([]string, len(resolvers)) + copy(shuffled, resolvers) + rand.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] }) + + for _, resolver := range shuffled { + if ctx.Err() != nil { + return nil, ctx.Err() + } + data, err := f.queryResolver(ctx, resolver, qname) + if err != nil { + lastErr = err + continue + } + return data, nil + } + lastErr = fmt.Errorf("all resolvers failed: %w", lastErr) + if attempt+1 < maxAttempts { + f.log("block ch=%d blk=%d attempt %d/%d failed, retrying: %v", channel, block, attempt+1, maxAttempts, lastErr) + } + } + return nil, lastErr } -// FetchMetadata fetches and parses metadata (channel 0). -func (f *Fetcher) FetchMetadata() (*protocol.Metadata, error) { - data, err := f.FetchBlock(protocol.MetadataChannel, 0) +// FetchMetadata fetches and parses the metadata block (channel 0). +func (f *Fetcher) FetchMetadata(ctx context.Context) (*protocol.Metadata, error) { + data, err := f.FetchBlock(ctx, protocol.MetadataChannel, 0) if err != nil { return nil, fmt.Errorf("fetch metadata block 0: %w", err) } @@ -152,12 +310,15 @@ func (f *Fetcher) FetchMetadata() (*protocol.Metadata, error) { return meta, nil } - // Metadata might span multiple blocks + // Metadata may span multiple blocks. allData := make([]byte, len(data)) copy(allData, data) for blk := uint16(1); blk < 10; blk++ { - block, fetchErr := f.FetchBlock(protocol.MetadataChannel, blk) + if ctx.Err() != nil { + return nil, ctx.Err() + } + block, fetchErr := f.FetchBlock(ctx, protocol.MetadataChannel, blk) if fetchErr != nil { break } @@ -171,31 +332,41 @@ func (f *Fetcher) FetchMetadata() (*protocol.Metadata, error) { return nil, fmt.Errorf("could not parse metadata: %w", err) } -// FetchChannel fetches all blocks for a channel and parses messages. -func (f *Fetcher) FetchChannel(channelNum int, blockCount int) ([]protocol.Message, error) { +// FetchChannel fetches all blocks for a channel and returns the parsed messages. +// Cancelling ctx immediately aborts any queued or in-flight block fetches. +// Each block is retried individually via FetchBlock before the channel fetch fails. +func (f *Fetcher) FetchChannel(ctx context.Context, channelNum int, blockCount int) ([]protocol.Message, error) { if blockCount <= 0 { return nil, nil } - type result struct { + type blockResult struct { idx int data []byte err error } - results := make(chan result, blockCount) - // Limit concurrency to 3 to reduce DNS burst traffic - sem := make(chan struct{}, 3) + results := make(chan blockResult, blockCount) + // Cap concurrent DNS queries at 5; the token-bucket rate limiter provides + // the actual throughput control regardless of this concurrency cap. + sem := make(chan struct{}, 5) var wg sync.WaitGroup for i := 0; i < blockCount; i++ { wg.Add(1) go func(idx int) { defer wg.Done() - sem <- struct{}{} + // Acquire semaphore or bail on cancellation. + select { + case sem <- struct{}{}: + case <-ctx.Done(): + results <- blockResult{idx: idx, err: ctx.Err()} + return + } defer func() { <-sem }() - data, err := f.FetchBlock(uint16(channelNum), uint16(idx)) - results <- result{idx: idx, data: data, err: err} + + data, err := f.FetchBlock(ctx, uint16(channelNum), uint16(idx)) + results <- blockResult{idx: idx, data: data, err: err} }(i) } @@ -205,24 +376,33 @@ func (f *Fetcher) FetchChannel(channelNum int, blockCount int) ([]protocol.Messa }() ordered := make([][]byte, blockCount) + completed := 0 for r := range results { if r.err != nil { - return nil, fmt.Errorf("fetch block %d: %w", r.idx, r.err) + if r.err == ctx.Err() { + // Context cancelled — abort immediately. + return nil, r.err + } + // FetchBlock already retried internally; log and treat as fatal for this channel. + f.log("Channel %d block %d permanently failed: %v", channelNum, r.idx, r.err) + return nil, fmt.Errorf("channel %d block %d: %w", channelNum, r.idx, r.err) } ordered[r.idx] = r.data + completed++ + f.logProgress(fmt.Sprintf("Channel %d", channelNum), float64(completed), float64(blockCount)) } var allData []byte - for _, block := range ordered { - allData = append(allData, block...) + for _, b := range ordered { + allData = append(allData, b...) } return protocol.ParseMessages(allData) } -func (f *Fetcher) queryResolver(resolver, qname string) ([]byte, error) { +func (f *Fetcher) queryResolver(ctx context.Context, resolver, qname string) ([]byte, error) { if !strings.Contains(resolver, ":") { - resolver = resolver + ":53" + resolver += ":53" } c := new(dns.Client) @@ -231,8 +411,9 @@ func (f *Fetcher) queryResolver(resolver, qname string) ([]byte, error) { m := new(dns.Msg) m.SetQuestion(dns.Fqdn(qname), dns.TypeTXT) m.RecursionDesired = true + m.SetEdns0(4096, false) // advertise 4 KiB UDP buffer to avoid response truncation - resp, _, err := c.Exchange(m, resolver) + resp, _, err := c.ExchangeContext(ctx, m, resolver) if err != nil { return nil, fmt.Errorf("dns exchange with %s: %w", resolver, err) } diff --git a/internal/client/resolver.go b/internal/client/resolver.go index 7494af2..36cd41f 100644 --- a/internal/client/resolver.go +++ b/internal/client/resolver.go @@ -1,181 +1,129 @@ package client import ( - "bufio" + "context" "fmt" - "log" - "net" - "os" "strings" "sync" - "sync/atomic" "time" + + "github.com/miekg/dns" + + "github.com/sartoopjj/thefeed/internal/protocol" ) -// ResolverScanner scans CIDR ranges to find working DNS resolvers. -type ResolverScanner struct { - fetcher *Fetcher - concurrency int - timeout time.Duration +// ResolverChecker periodically probes the fetcher's configured resolvers and +// updates the active (healthy) resolver pool. It replaces the old file/CIDR +// scanner — no file I/O; just a plain DNS probe on channel 0. +type ResolverChecker struct { + fetcher *Fetcher + timeout time.Duration + logFunc LogFunc } -// NewResolverScanner creates a resolver scanner. -func NewResolverScanner(fetcher *Fetcher, concurrency int) *ResolverScanner { - if concurrency <= 0 { - concurrency = 50 +// NewResolverChecker creates a health checker for the resolvers in fetcher. +// timeout is the per-probe deadline; 0 uses a 5-second default. +func NewResolverChecker(fetcher *Fetcher, timeout time.Duration) *ResolverChecker { + if timeout <= 0 { + timeout = 5 * time.Second } - return &ResolverScanner{ - fetcher: fetcher, - concurrency: concurrency, - timeout: 3 * time.Second, + return &ResolverChecker{ + fetcher: fetcher, + timeout: timeout, } } -// ScanCIDR scans a CIDR range for working DNS resolvers. -func (rs *ResolverScanner) ScanCIDR(cidr string, onFound func(ip string)) error { - _, ipNet, err := net.ParseCIDR(cidr) - if err != nil { - return fmt.Errorf("parse CIDR %q: %w", cidr, err) - } - - ips := expandCIDR(ipNet) - return rs.scanIPs(ips, onFound) +// SetLogFunc sets the callback used to emit health-check results to the log panel. +func (rc *ResolverChecker) SetLogFunc(fn LogFunc) { + rc.logFunc = fn } -// ScanFile scans resolver IPs from a file (one per line, supports CIDR notation). -func (rs *ResolverScanner) ScanFile(path string, onFound func(ip string)) error { - f, err := os.Open(path) - if err != nil { - return err - } - defer f.Close() - - var ips []string - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - if strings.Contains(line, "/") { - _, ipNet, err := net.ParseCIDR(line) - if err != nil { - log.Printf("[resolver] skip invalid CIDR: %s", line) - continue +// Start begins the periodic health-check loop in the background. +// An initial check runs immediately; subsequent checks happen every 10 minutes. +// ctx controls the lifetime — cancel it to stop the checker. +func (rc *ResolverChecker) Start(ctx context.Context) { + go func() { + rc.runCheck() + ticker := time.NewTicker(10 * time.Minute) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + rc.runCheck() } - ips = append(ips, expandCIDR(ipNet)...) - } else { - ips = append(ips, line) } - } - if err := scanner.Err(); err != nil { - return err - } - - return rs.scanIPs(ips, onFound) + }() } -// CheckResolver tests if a single resolver works by querying metadata. -func (rs *ResolverScanner) CheckResolver(ip string) bool { - if !strings.Contains(ip, ":") { - ip = ip + ":53" +func (rc *ResolverChecker) runCheck() { + resolvers := rc.fetcher.AllResolvers() + if len(resolvers) == 0 { + return } - // Create a new fetcher with only this resolver to avoid copying the lock. - tmpFetcher := &Fetcher{ - domain: rs.fetcher.domain, - queryKey: rs.fetcher.queryKey, - responseKey: rs.fetcher.responseKey, - resolvers: []string{ip}, - timeout: rs.timeout, - } + rc.log("Checking %d resolver(s)...", len(resolvers)) - _, err := tmpFetcher.FetchBlock(0, 0) - return err == nil -} - -func (rs *ResolverScanner) scanIPs(ips []string, onFound func(ip string)) error { - if len(ips) == 0 { - return fmt.Errorf("no IPs to scan") - } - - var found atomic.Int32 - sem := make(chan struct{}, rs.concurrency) + var healthy []string + var mu sync.Mutex var wg sync.WaitGroup + sem := make(chan struct{}, 10) // probe up to 10 resolvers concurrently - for _, ip := range ips { + for _, r := range resolvers { wg.Add(1) - go func(ip string) { + go func(r string) { defer wg.Done() sem <- struct{}{} defer func() { <-sem }() - if rs.CheckResolver(ip) { - found.Add(1) - if onFound != nil { - onFound(ip) - } + if rc.checkOne(r) { + mu.Lock() + healthy = append(healthy, r) + mu.Unlock() + rc.log("Resolver OK: %s", r) + } else { + rc.log("Resolver failed: %s", r) } - }(ip) + }(r) } - wg.Wait() - if found.Load() == 0 { - return fmt.Errorf("no working resolvers found among %d IPs", len(ips)) - } - return nil + rc.fetcher.SetActiveResolvers(healthy) + rc.log("Resolver check done: %d/%d healthy", len(healthy), len(resolvers)) } -// LoadResolversFile loads resolver IPs from a file (one per line). -func LoadResolversFile(path string) ([]string, error) { - f, err := os.Open(path) +// checkOne probes a single resolver by sending a metadata channel query +// (channel 0, block 0). A successful DNS response (any rcode that isn't a +// network/timeout error) means the resolver is reachable and understands the domain. +func (rc *ResolverChecker) checkOne(resolver string) bool { + if !strings.Contains(resolver, ":") { + resolver += ":53" + } + + qname, err := protocol.EncodeQuery( + rc.fetcher.queryKey, + protocol.MetadataChannel, 0, + rc.fetcher.domain, + rc.fetcher.queryMode, + ) if err != nil { - return nil, err + return false } - defer f.Close() - var resolvers []string - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - resolvers = append(resolvers, line) - } - return resolvers, scanner.Err() + c := &dns.Client{Timeout: rc.timeout} + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(qname), dns.TypeTXT) + m.RecursionDesired = true + + resp, _, err := c.Exchange(m, resolver) + // We consider the resolver healthy if we get any DNS response back + // (even NXDOMAIN means the resolver forwarded the query to our server). + return err == nil && resp != nil } -func expandCIDR(ipNet *net.IPNet) []string { - var ips []string - ip := ipNet.IP.Mask(ipNet.Mask) - - for ip := cloneIP(ip); ipNet.Contains(ip); incIP(ip) { - // Skip network and broadcast addresses for /24 and smaller - ones, bits := ipNet.Mask.Size() - if bits-ones <= 8 { - last := ip[len(ip)-1] - if last == 0 || last == 255 { - continue - } - } - ips = append(ips, ip.String()) - } - return ips -} - -func cloneIP(ip net.IP) net.IP { - dup := make(net.IP, len(ip)) - copy(dup, ip) - return dup -} - -func incIP(ip net.IP) { - for j := len(ip) - 1; j >= 0; j-- { - ip[j]++ - if ip[j] > 0 { - break - } +func (rc *ResolverChecker) log(format string, args ...any) { + if rc.logFunc != nil { + rc.logFunc(fmt.Sprintf(format, args...)) } } diff --git a/internal/protocol/crypto.go b/internal/protocol/crypto.go index c82ccd6..e37b07f 100644 --- a/internal/protocol/crypto.go +++ b/internal/protocol/crypto.go @@ -67,3 +67,42 @@ func Decrypt(key [KeySize]byte, ciphertext []byte) ([]byte, error) { nonce := ciphertext[:gcm.NonceSize()] return gcm.Open(nil, nonce, ciphertext[gcm.NonceSize():], nil) } + +// encryptQueryBlock encrypts an 8-byte query payload using a direct AES-256 block cipher. +// The payload is expanded to one AES block (16 bytes) with 8 trailing zero bytes before +// encryption. No nonce or auth tag needed: the 4 random bytes in the payload guarantee +// unique ciphertext per query. Result is always 16 bytes. +func encryptQueryBlock(key [KeySize]byte, payload []byte) ([]byte, error) { + if len(payload) != QueryPayloadSize { + return nil, fmt.Errorf("encryptQueryBlock: payload must be %d bytes, got %d", QueryPayloadSize, len(payload)) + } + block, err := aes.NewCipher(key[:]) + if err != nil { + return nil, err + } + var buf [aes.BlockSize]byte + copy(buf[:QueryPayloadSize], payload) // bytes 8-15 stay zero + block.Encrypt(buf[:], buf[:]) + return buf[:], nil +} + +// decryptQueryBlock decrypts a query ciphertext produced by encryptQueryBlock. +// Accepts ciphertext with optional random suffix bytes (≥ BlockSize); only the +// first BlockSize bytes are used. Verifies the last 8 bytes of plaintext are zero. +func decryptQueryBlock(key [KeySize]byte, ciphertext []byte) ([]byte, error) { + if len(ciphertext) < aes.BlockSize { + return nil, fmt.Errorf("decryptQueryBlock: need at least %d bytes, got %d", aes.BlockSize, len(ciphertext)) + } + block, err := aes.NewCipher(key[:]) + if err != nil { + return nil, err + } + var buf [aes.BlockSize]byte + block.Decrypt(buf[:], ciphertext[:aes.BlockSize]) // ignore suffix + for i := QueryPayloadSize; i < aes.BlockSize; i++ { + if buf[i] != 0 { + return nil, fmt.Errorf("decryptQueryBlock: integrity check failed (wrong key?)") + } + } + return buf[:QueryPayloadSize], nil +} diff --git a/internal/protocol/dns.go b/internal/protocol/dns.go index a61e1ad..423f221 100644 --- a/internal/protocol/dns.go +++ b/internal/protocol/dns.go @@ -7,26 +7,48 @@ import ( "encoding/binary" "encoding/hex" "fmt" + "math/big" + "strconv" "strings" ) +const ( + maxDNSLabelLen = 63 + maxDNSNameLen = 253 // without trailing dot +) + // QueryEncoding controls how DNS query subdomains are encoded. type QueryEncoding int const ( // QuerySingleLabel uses base32 in a single DNS label (default, stealthier). QuerySingleLabel QueryEncoding = iota - // QueryDoubleLabel uses hex split across two DNS labels. - QueryDoubleLabel + // QueryMultiLabel uses hex split across multiple DNS labels. + QueryMultiLabel + // QueryPlainLabel encodes channel and block as plain decimal text (no query encryption). + // Responses are always encrypted regardless of this setting. + QueryPlainLabel ) var b32 = base32.StdEncoding.WithPadding(base32.NoPadding) -// EncodeQuery creates an encrypted DNS query subdomain. +// EncodeQuery creates a DNS query subdomain for the given channel and block. // Single-label (default): [base32_encrypted].domain -// Double-label: [hex_part1].[hex_part2].domain -// Payload: 4 random + 2 channel + 2 block = 8 bytes, encrypted with AES-GCM. +// Multi-label: [hex_part1].[hex_part2].domain +// Plain-label: cb.domain (no query encryption) +// Responses are always encrypted regardless of mode. func EncodeQuery(queryKey [KeySize]byte, channel, block uint16, domain string, mode QueryEncoding) (string, error) { + domain = strings.TrimSuffix(domain, ".") + if domain == "" { + return "", fmt.Errorf("empty domain") + } + + // Plain text mode: no encryption, just human-readable label. + if mode == QueryPlainLabel { + label := fmt.Sprintf("c%db%d", channel, block) + return joinQName([]string{label}, domain) + } + payload := make([]byte, QueryPayloadSize) if _, err := rand.Read(payload[:QueryPaddingSize]); err != nil { @@ -36,24 +58,88 @@ func EncodeQuery(queryKey [KeySize]byte, channel, block uint16, domain string, m binary.BigEndian.PutUint16(payload[QueryPaddingSize:], channel) binary.BigEndian.PutUint16(payload[QueryPaddingSize+QueryChannelSize:], block) - encrypted, err := Encrypt(queryKey, payload) + encrypted, err := encryptQueryBlock(queryKey, payload) if err != nil { return "", fmt.Errorf("encrypt query: %w", err) } + // Append 0–4 random suffix bytes so query length varies per request. + // The decoder strips these by only using the first aes.BlockSize bytes. + suffixLen, _ := rand.Int(rand.Reader, big.NewInt(5)) // [0,4] + suffix := make([]byte, int(suffixLen.Int64())) + rand.Read(suffix) //nolint:errcheck — non-critical randomness + ciphertext := append(encrypted, suffix...) + switch mode { - case QueryDoubleLabel: - h := hex.EncodeToString(encrypted) - mid := len(h) / 2 - return fmt.Sprintf("%s.%s.%s", h[:mid], h[mid:], domain), nil + case QueryMultiLabel: + h := hex.EncodeToString(ciphertext) + labels := splitMultiLabel(h) + return joinQName(labels, domain) default: - encoded := strings.ToLower(b32.EncodeToString(encrypted)) - return fmt.Sprintf("%s.%s", encoded, domain), nil + encoded := strings.ToLower(b32.EncodeToString(ciphertext)) + return joinQName([]string{encoded}, domain) } } +func splitLabel(s string, size int) []string { + if size <= 0 { + size = maxDNSLabelLen + } + if size > maxDNSLabelLen { + size = maxDNSLabelLen + } + + parts := make([]string, 0, (len(s)+size-1)/size) + for len(s) > size { + parts = append(parts, s[:size]) + s = s[size:] + } + if len(s) > 0 { + parts = append(parts, s) + } + return parts +} + +// splitMultiLabel splits a hex string into two labels of randomised, unequal length. +// The first label is between 12 and (len-4) chars so the second is at least 4 chars. +// This makes query labels look less uniform across requests. +func splitMultiLabel(h string) []string { + if len(h) <= 8 { + return []string{h} + } + // first label: random length in [minFirst, len-4] + minFirst := 8 + maxFirst := len(h) - 4 + if maxFirst <= minFirst { + maxFirst = minFirst + 1 + } + // crypto/rand for the split point; fall back to midpoint on error + split := (len(h) + 1) / 2 // default: slightly off-centre + if n, err := rand.Int(rand.Reader, big.NewInt(int64(maxFirst-minFirst+1))); err == nil { + split = minFirst + int(n.Int64()) + } + return []string{h[:split], h[split:]} +} + +func joinQName(labels []string, domain string) (string, error) { + for _, l := range labels { + if len(l) == 0 { + return "", fmt.Errorf("empty label") + } + if len(l) > maxDNSLabelLen { + return "", fmt.Errorf("label too long: %d", len(l)) + } + } + + qname := strings.Join(append(labels, domain), ".") + if len(qname) > maxDNSNameLen { + return "", fmt.Errorf("query name too long: %d", len(qname)) + } + return qname, nil +} + // DecodeQuery parses and decrypts a DNS query subdomain. -// Auto-detects single-label (base32) or double-label (hex) encoding. +// Auto-detects plain-text (cb), single-label base32, or multi-label hex encoding. func DecodeQuery(queryKey [KeySize]byte, qname, domain string) (channel, block uint16, err error) { qname = strings.TrimSuffix(qname, ".") domain = strings.TrimSuffix(domain, ".") @@ -65,32 +151,54 @@ func DecodeQuery(queryKey [KeySize]byte, qname, domain string) (channel, block u encoded := qname[:len(qname)-len(suffix)] - // Try base32 first (single label, no dots, or dots stripped) - b32str := strings.ReplaceAll(encoded, ".", "") - ciphertext, err := b32.DecodeString(strings.ToUpper(b32str)) - if err == nil { - return decryptQuery(queryKey, ciphertext) + // Try plain-label first: cb (short, no dots, all decimal) + if ch, blk, ok := parsePlainLabel(encoded); ok { + return ch, blk, nil } - // Fall back to hex (double-label) + // Try base32 (single-label: no dots or dots stripped) + b32str := strings.ReplaceAll(encoded, ".", "") + if ct, e := b32.DecodeString(strings.ToUpper(b32str)); e == nil { + return parseQueryCiphertext(queryKey, ct) + } + + // Fall back to hex (multi-label: dots stripped) hexStr := strings.ReplaceAll(encoded, ".", "") - ciphertext, err = hex.DecodeString(hexStr) - if err != nil { + ct, e := hex.DecodeString(hexStr) + if e != nil { return 0, 0, fmt.Errorf("decode query: invalid encoding") } - return decryptQuery(queryKey, ciphertext) + return parseQueryCiphertext(queryKey, ct) } -func decryptQuery(queryKey [KeySize]byte, ciphertext []byte) (channel, block uint16, err error) { - plaintext, err := Decrypt(queryKey, ciphertext) +// parsePlainLabel parses the plain-text query format "cb". +// Returns ok=false if the string does not match this pattern. +func parsePlainLabel(s string) (channel, block uint16, ok bool) { + if len(s) < 3 || s[0] != 'c' { + return 0, 0, false + } + bi := strings.IndexByte(s[1:], 'b') + if bi < 0 { + return 0, 0, false + } + bi++ // adjust for the slice offset + chStr, bStr := s[1:bi], s[bi+1:] + if len(chStr) == 0 || len(bStr) == 0 { + return 0, 0, false + } + ch, err1 := strconv.ParseUint(chStr, 10, 16) + blk, err2 := strconv.ParseUint(bStr, 10, 16) + if err1 != nil || err2 != nil { + return 0, 0, false + } + return uint16(ch), uint16(blk), true +} + +func parseQueryCiphertext(queryKey [KeySize]byte, ciphertext []byte) (channel, block uint16, err error) { + plaintext, err := decryptQueryBlock(queryKey, ciphertext) if err != nil { return 0, 0, fmt.Errorf("decrypt: %w", err) } - - if len(plaintext) != QueryPayloadSize { - return 0, 0, fmt.Errorf("invalid payload size: %d", len(plaintext)) - } - channel = binary.BigEndian.Uint16(plaintext[QueryPaddingSize:]) block = binary.BigEndian.Uint16(plaintext[QueryPaddingSize+QueryChannelSize:]) return channel, block, nil diff --git a/internal/protocol/dns_test.go b/internal/protocol/dns_test.go index 49ad3b3..92f6b5e 100644 --- a/internal/protocol/dns_test.go +++ b/internal/protocol/dns_test.go @@ -1,6 +1,7 @@ package protocol import ( + "fmt" "strings" "testing" ) @@ -43,21 +44,26 @@ func TestEncodeDecodeQuerySingleLabel(t *testing.T) { } } -func TestEncodeDecodeQueryDoubleLabel(t *testing.T) { +func TestEncodeDecodeQueryMultiLabel(t *testing.T) { qk, _, err := DeriveKeys("test-key") if err != nil { t.Fatalf("DeriveKeys: %v", err) } domain := "t.example.com" - qname, err := EncodeQuery(qk, 3, 7, domain, QueryDoubleLabel) + qname, err := EncodeQuery(qk, 3, 7, domain, QueryMultiLabel) if err != nil { t.Fatal(err) } - // Double label: two hex labels before domain + // Multi-label mode splits hex across labels; all must be DNS-safe. subdomain := qname[:len(qname)-len(domain)-1] parts := strings.Split(subdomain, ".") - if len(parts) != 2 { - t.Errorf("double-label query should have 2 parts, got %d: %q", len(parts), subdomain) + if len(parts) < 1 { + t.Errorf("multi-label query should have at least 1 part, got %d: %q", len(parts), subdomain) + } + for _, p := range parts { + if len(p) == 0 || len(p) > 63 { + t.Errorf("invalid label length %d in %q", len(p), p) + } } ch, blk, err := DecodeQuery(qk, qname, domain) if err != nil { @@ -68,6 +74,73 @@ func TestEncodeDecodeQueryDoubleLabel(t *testing.T) { } } +func TestEncodeQueryTooLongDomain(t *testing.T) { + qk, _, err := DeriveKeys("test-key") + if err != nil { + t.Fatalf("DeriveKeys: %v", err) + } + + // 250-char domain should make qname exceed DNS 253-char limit. + longDomain := strings.Repeat("a", 250) + _, err = EncodeQuery(qk, 1, 1, longDomain, QueryMultiLabel) + if err == nil { + t.Fatal("expected error for too-long domain") + } +} + +func TestEncodeDecodeQueryPlainLabel(t *testing.T) { + qk, _, err := DeriveKeys("test-key") + if err != nil { + t.Fatalf("DeriveKeys: %v", err) + } + domain := "t.example.com" + tests := []struct { + channel uint16 + block uint16 + }{ + {0, 0}, + {1, 42}, + {255, 65535}, + {3, 100}, + } + for _, tt := range tests { + qname, err := EncodeQuery(qk, tt.channel, tt.block, domain, QueryPlainLabel) + if err != nil { + t.Fatalf("EncodeQuery(%d, %d): %v", tt.channel, tt.block, err) + } + // Label should be "cb" — human readable, no padding hex. + want := fmt.Sprintf("c%db%d.%s", tt.channel, tt.block, domain) + if qname != want { + t.Errorf("got %q, want %q", qname, want) + } + // DecodeQuery must recover channel and block regardless of key. + ch, blk, err := DecodeQuery(qk, qname, domain) + if err != nil { + t.Fatalf("DecodeQuery: %v", err) + } + if ch != tt.channel || blk != tt.block { + t.Errorf("got ch=%d blk=%d, want ch=%d blk=%d", ch, blk, tt.channel, tt.block) + } + } +} + +func TestPlainLabelNotConfusedWithEncrypted(t *testing.T) { + qk, _, err := DeriveKeys("test-key") + if err != nil { + t.Fatalf("DeriveKeys: %v", err) + } + domain := "t.example.com" + // Encode with single-label then check that DecodeQuery does NOT treat it as plain. + qname, _ := EncodeQuery(qk, 5, 10, domain, QuerySingleLabel) + ch, blk, err := DecodeQuery(qk, qname, domain) + if err != nil { + t.Fatalf("DecodeQuery single-label: %v", err) + } + if ch != 5 || blk != 10 { + t.Errorf("got ch=%d blk=%d, want ch=5 blk=10", ch, blk) + } +} + func TestDecodeQueryWrongKey(t *testing.T) { qk1, _, _ := DeriveKeys("key1") qk2, _, _ := DeriveKeys("key2") diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index fd4e63f..05364d0 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -1,14 +1,21 @@ package protocol import ( + "crypto/rand" "encoding/binary" "fmt" + "math/big" ) const ( - // DefaultBlockPayload is the decrypted payload per DNS TXT block. - // Calculated to stay within 512-byte UDP DNS limit after encryption + base64 + padding overhead. - DefaultBlockPayload = 180 + // MinBlockPayload is the minimum decrypted payload per DNS TXT block. + MinBlockPayload = 200 + // MaxBlockPayload is the maximum decrypted payload per DNS TXT block. + // 600 bytes data + 28 GCM overhead + 2 prefix + 32 padding → ~856 base64 chars. + // Well within the 4096-byte EDNS0 UDP buffer the client advertises. + MaxBlockPayload = 600 + // DefaultBlockPayload is kept for compatibility; equals MaxBlockPayload. + DefaultBlockPayload = MaxBlockPayload // DefaultMaxPadding is the default random padding added to responses to vary DNS response size. DefaultMaxPadding = 32 @@ -210,20 +217,31 @@ func ParseMessages(data []byte) ([]Message, error) { return msgs, nil } -// SplitIntoBlocks splits data into blocks of DefaultBlockPayload size. +// SplitIntoBlocks splits data into blocks of randomly varying size in [MinBlockPayload, MaxBlockPayload]. +// Random sizes make traffic analysis harder; the client just concatenates all blocks to reassemble. func SplitIntoBlocks(data []byte) [][]byte { if len(data) == 0 { - return [][]byte{{}} + return [][]byte{{}} // channel 0 block 0 must always exist } var blocks [][]byte - for i := 0; i < len(data); i += DefaultBlockPayload { - end := i + DefaultBlockPayload - if end > len(data) { - end = len(data) + rem := data + for len(rem) > 0 { + size := randBlockSize() + if size > len(rem) { + size = len(rem) } - block := make([]byte, end-i) - copy(block, data[i:end]) + block := make([]byte, size) + copy(block, rem[:size]) blocks = append(blocks, block) + rem = rem[size:] } return blocks } + +func randBlockSize() int { + n, err := rand.Int(rand.Reader, big.NewInt(int64(MaxBlockPayload-MinBlockPayload+1))) + if err != nil { + return (MinBlockPayload + MaxBlockPayload) / 2 + } + return MinBlockPayload + int(n.Int64()) +} diff --git a/internal/protocol/protocol_test.go b/internal/protocol/protocol_test.go index 01434ee..26f0ade 100644 --- a/internal/protocol/protocol_test.go +++ b/internal/protocol/protocol_test.go @@ -57,19 +57,19 @@ func TestSerializeParseMessages(t *testing.T) { } func TestSplitIntoBlocks(t *testing.T) { - data := bytes.Repeat([]byte("A"), DefaultBlockPayload*3+50) + // MaxBlockPayload*3+50 guarantees at least 4 blocks (ceil((MaxBlockPayload*3+50)/MaxBlockPayload) = 4). + data := bytes.Repeat([]byte("A"), MaxBlockPayload*3+50) blocks := SplitIntoBlocks(data) - if len(blocks) != 4 { - t.Fatalf("blocks: got %d, want 4", len(blocks)) + if len(blocks) < 4 { + t.Fatalf("expected at least 4 blocks for %d bytes, got %d", len(data), len(blocks)) } - for i, b := range blocks { - if i < 3 && len(b) != DefaultBlockPayload { - t.Errorf("block %d: size %d, want %d", i, len(b), DefaultBlockPayload) + // Every non-last block must be within [MinBlockPayload, MaxBlockPayload]. + for i, b := range blocks[:len(blocks)-1] { + if len(b) < MinBlockPayload || len(b) > MaxBlockPayload { + t.Errorf("block %d: size %d, want [%d, %d]", i, len(b), MinBlockPayload, MaxBlockPayload) } } - if len(blocks[3]) != 50 { - t.Errorf("last block: size %d, want 50", len(blocks[3])) - } + // Reassembled data must equal original. var reassembled []byte for _, b := range blocks { reassembled = append(reassembled, b...) diff --git a/internal/server/feed_test.go b/internal/server/feed_test.go index 285e12c..09ff433 100644 --- a/internal/server/feed_test.go +++ b/internal/server/feed_test.go @@ -71,7 +71,9 @@ func TestFeedGetBlockUnknownChannel(t *testing.T) { func TestFeedLargeMessages(t *testing.T) { feed := NewFeed([]string{"Test"}) - largeText := make([]byte, 500) + // Use text large enough to span 2 blocks at DefaultBlockPayload (currently 700 bytes). + // Message serialization overhead is 10 bytes, so we need >690 bytes of text. + largeText := make([]byte, 750) for i := range largeText { largeText[i] = 65 } diff --git a/internal/server/telegram.go b/internal/server/telegram.go index 775b2ca..208f0e6 100644 --- a/internal/server/telegram.go +++ b/internal/server/telegram.go @@ -80,7 +80,7 @@ func NewTelegramReader(cfg TelegramConfig, channelUsernames []string, feed *Feed channels: cleaned, feed: feed, cache: make(map[string]cachedMessages), - cacheTTL: 1 * time.Minute, + cacheTTL: 5 * time.Minute, } } @@ -115,7 +115,7 @@ func (tr *TelegramReader) Run(ctx context.Context) error { tr.fetchAll(ctx, api) // Periodic fetch loop - ticker := time.NewTicker(1 * time.Minute) + ticker := time.NewTicker(3 * time.Minute) defer ticker.Stop() for { diff --git a/internal/web/static/index.html b/internal/web/static/index.html index b87d218..f8dbfb1 100644 --- a/internal/web/static/index.html +++ b/internal/web/static/index.html @@ -142,6 +142,44 @@ word-break: break-word; } + .progress-panel { + height: 56px; + min-height: 56px; + background: var(--surface); + border-top: 1px solid var(--border); + border-bottom: 1px solid var(--border); + overflow-y: auto; + font-size: 12px; + font-family: 'Vazirmatn', monospace; + padding: 12px 12px; + direction: ltr; + text-align: left; + color: var(--text-dim); + } + .progress-item { + margin-bottom: 8px; + padding: 6px; + border-radius: 4px; + background: var(--bg); + } + .progress-label { + font-size: 11px; + margin-bottom: 3px; + color: var(--text-dim); + } + .progress-bar { + width: 100%; + height: 6px; + background: var(--border); + border-radius: 3px; + overflow: hidden; + } + .progress-fill { + height: 100%; + background: var(--accent); + transition: width 0.2s; + } + .log-panel { height: 150px; min-height: 100px; @@ -149,13 +187,21 @@ border-top: 1px solid var(--border); overflow-y: auto; font-size: 12px; - font-family: 'Vazirmatn', monospace; + font-family: monospace; padding: 8px 12px; direction: ltr; text-align: left; color: var(--text-dim); } - .log-line { padding: 2px 0; } + .log-line { + padding: 2px 0; + line-height: 1.3; + } + .log-line.info { color: #60a5fa; } + .log-line.progress { color: #fbbf24; } + .log-line.error { color: #ef4444; } + .log-line.success { color: #22c55e; } + .log-line.warning { color: #f97316; } .modal-overlay { display: none; @@ -245,6 +291,7 @@ +
@@ -268,12 +315,17 @@
- + +
+
+ +
@@ -287,6 +339,7 @@ let selectedChannel = 0; let channels = []; let eventSource = null; + let autoRefreshTimer = null; async function init() { try { @@ -296,11 +349,16 @@ openSettings(); } else { document.getElementById('statusDot').className = 'status-dot connected'; - if (status.channels && status.channels.length > 0) { - channels = status.channels; - renderChannels(); - selectChannel(1); + await doRefresh(); + await loadChannels(); + if (channels && channels.length > 0) { + await selectChannel(1); } + autoRefreshTimer = setInterval(function() { + if (selectedChannel > 0) { + doRefresh(); + } + }, 120000); } } catch (e) {} connectSSE(); @@ -312,12 +370,23 @@ eventSource.addEventListener('log', function(e) { addLogLine(JSON.parse(e.data)); }); - eventSource.addEventListener('update', function(e) { - loadChannels(); - if (selectedChannel > 0) loadMessages(selectedChannel); + eventSource.addEventListener('update', async function(e) { + var wasEmpty = channels.length === 0; + await loadChannels(); + if (wasEmpty && channels.length > 0 && selectedChannel === 0) { + // Channels just appeared for the first time — auto-select the first one. + selectChannel(1); + } else if (selectedChannel > 0) { + loadMessages(selectedChannel); + } }); eventSource.onerror = function() { document.getElementById('statusDot').className = 'status-dot disconnected'; + // EventSource.CLOSED (2) means the browser stopped retrying — reconnect manually. + if (eventSource.readyState === EventSource.CLOSED) { + eventSource.close(); + setTimeout(connectSSE, 3000); + } }; eventSource.onopen = function() { document.getElementById('statusDot').className = 'status-dot connected'; @@ -343,6 +412,7 @@ async function selectChannel(num) { selectedChannel = num; renderChannels(); + await doRefresh(); await loadMessages(num); } @@ -392,16 +462,79 @@ var el = document.getElementById('logPanel'); var div = document.createElement('div'); div.className = 'log-line'; + + // Parse log level from line content + var level = 'info'; + var displayText = line; + + if (typeof line === 'string') { + if (line.includes('Error:') || line.includes('error') || line.includes('Error')) { + level = 'error'; + } else if (line.includes('Warning:') || line.includes('warning') || line.includes('Warning')) { + level = 'warning'; + } else if (line.includes('OK:') || line.includes('OK') || line.includes('success') || line.includes('done')) { + level = 'success'; + } else if (line.match(/\[\d+%\]|\d+\s*%/)) { + level = 'progress'; + updateProgressDisplay(line); + return; + } + } + + div.className = 'log-line ' + level; div.textContent = line; el.appendChild(div); el.scrollTop = el.scrollHeight; + + // Keep only last 200 lines while (el.children.length > 200) { el.removeChild(el.firstChild); } } + function updateProgressDisplay(line) { + // Parse progress from "Channel N [====> ] 45%" + var match = line.match(/Channel\s+(\d+)/); + if (!match) return; + + var channelNum = match[1]; + var percentMatch = line.match(/(\d+)%/); + var percent = percentMatch ? parseInt(percentMatch[1]) : 0; + + var panel = document.getElementById('progressPanel'); + var itemId = 'progress-current'; + var item = document.getElementById(itemId); + + if (!item) { + item = document.createElement('div'); + item.id = itemId; + item.className = 'progress-item'; + item.innerHTML = '
Channel ' + channelNum + '
' + + '
'; + panel.appendChild(item); + } else { + item.querySelector('.progress-label').textContent = 'Channel ' + channelNum; + } + + var fill = item.querySelector('.progress-fill'); + fill.style.width = percent + '%'; + + // Remove if complete + if (percent >= 100) { + setTimeout(function() { + if (item.parentNode) item.parentNode.removeChild(item); + }, 1000); + } + } + async function doRefresh() { - try { await fetch('/api/refresh', {method: 'POST'}); } catch (e) {} + try { + var url = '/api/refresh'; + if (selectedChannel > 0) { + url += '?channel=' + selectedChannel; + } + await fetch(url, {method: 'POST'}); + } catch (e) {} } function openSettings() { @@ -411,7 +544,12 @@ if (cfg.key) document.getElementById('cfgKey').value = cfg.key; if (cfg.resolvers) document.getElementById('cfgResolvers').value = cfg.resolvers.join('\n'); if (cfg.queryMode) document.getElementById('cfgQueryMode').value = cfg.queryMode; - if (cfg.rateLimit) document.getElementById('cfgRateLimit').value = cfg.rateLimit; + if (typeof cfg.rateLimit === 'number') { + document.getElementById('cfgRateLimit').value = cfg.rateLimit; + } else { + document.getElementById('cfgRateLimit').value = 100; + } + document.getElementById('cfgDebug').checked = !!cfg.debug; }).catch(function() {}); } @@ -432,7 +570,8 @@ key: document.getElementById('cfgKey').value, resolvers: resolvers, queryMode: document.getElementById('cfgQueryMode').value, - rateLimit: parseFloat(document.getElementById('cfgRateLimit').value) || 0 + rateLimit: parseFloat(document.getElementById('cfgRateLimit').value) || 100, + debug: document.getElementById('cfgDebug').checked }; if (!cfg.domain || !cfg.key || resolvers.length === 0) { @@ -453,9 +592,18 @@ } closeSettings(); document.getElementById('statusDot').className = 'status-dot connected'; - setTimeout(function() { - loadChannels(); - }, 2000); + await doRefresh(); + await loadChannels(); + if (channels && channels.length > 0) { + await selectChannel(1); + } + if (!autoRefreshTimer) { + autoRefreshTimer = setInterval(function() { + if (selectedChannel > 0) { + fetch('/api/refresh?channel=' + selectedChannel + '&quiet=1', {method: 'POST'}); + } + }, 120000); + } } catch (e) { errEl.textContent = e.message; errEl.style.display = 'block'; diff --git a/internal/web/web.go b/internal/web/web.go index c024f95..200daed 100644 --- a/internal/web/web.go +++ b/internal/web/web.go @@ -1,6 +1,7 @@ package web import ( + "context" "embed" "encoding/json" "fmt" @@ -29,6 +30,11 @@ type Config struct { Resolvers []string `json:"resolvers"` QueryMode string `json:"queryMode"` RateLimit float64 `json:"rateLimit"` + // Timeout is the per-query DNS timeout in seconds (0 = default 5 s). + // Also used as the resolver health-check probe timeout. + Timeout float64 `json:"timeout,omitempty"` + // Debug enables verbose query logging (shows generated DNS query names). + Debug bool `json:"debug,omitempty"` } // Server is the web UI server for thefeed client. @@ -43,6 +49,16 @@ type Server struct { channels []protocol.ChannelInfo messages map[int][]protocol.Message + // fetcherCtx/fetcherCancel control the lifetime of the active fetcher's + // background goroutines (rate limiter, noise, resolver checker). + // They are cancelled and recreated each time the config changes. + fetcherCtx context.Context + fetcherCancel context.CancelFunc + + // refreshMu / refreshCancel allow a new refresh to cancel an in-progress one. + refreshMu sync.Mutex + refreshCancel context.CancelFunc + logMu sync.RWMutex logLines []string @@ -96,7 +112,7 @@ func (s *Server) Run() error { fmt.Printf("\n Open in browser: http://%s\n\n", addr) if s.fetcher != nil { - s.startAutoRefresh() + go s.refreshMetadataOnly() } return http.ListenAndServe(addr, mux) @@ -164,7 +180,7 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("init fetcher: %v", err), 500) return } - s.startAutoRefresh() + go s.refreshMetadataOnly() writeJSON(w, map[string]any{"ok": true}) default: @@ -202,7 +218,28 @@ func (s *Server) handleRefresh(w http.ResponseWriter, r *http.Request) { http.Error(w, "method not allowed", 405) return } - go s.refresh() + // Background (quiet) refreshes skip silently if one is already running, + // so the auto-refresh timer never cancels a slow in-progress fetch. + if r.URL.Query().Get("quiet") == "1" { + s.refreshMu.Lock() + running := s.refreshCancel != nil + s.refreshMu.Unlock() + if running { + writeJSON(w, map[string]any{"ok": true, "skipped": true}) + return + } + } + chParam := r.URL.Query().Get("channel") + if chParam != "" { + chNum, err := strconv.Atoi(chParam) + if err != nil || chNum < 1 { + http.Error(w, "invalid channel", 400) + return + } + go s.refreshChannel(chNum) + } else { + go s.refreshMetadataOnly() + } writeJSON(w, map[string]any{"ok": true}) } @@ -278,6 +315,11 @@ func (s *Server) initFetcher() error { s.mu.Lock() defer s.mu.Unlock() + // Cancel goroutines from the previous fetcher configuration. + if s.fetcherCancel != nil { + s.fetcherCancel() + } + cfg := s.config if cfg == nil { return fmt.Errorf("no config") @@ -295,57 +337,81 @@ func (s *Server) initFetcher() error { } if cfg.QueryMode == "double" { - fetcher.SetQueryMode(protocol.QueryDoubleLabel) + fetcher.SetQueryMode(protocol.QueryMultiLabel) + } else if cfg.QueryMode == "plain" { + fetcher.SetQueryMode(protocol.QueryPlainLabel) } + fetcher.SetDebug(cfg.Debug) if cfg.RateLimit > 0 { fetcher.SetRateLimit(cfg.RateLimit) } + timeout := 5 * time.Second + if cfg.Timeout > 0 { + timeout = time.Duration(cfg.Timeout * float64(time.Second)) + } + fetcher.SetTimeout(timeout) + fetcher.SetLogFunc(func(msg string) { s.addLog(msg) }) + // Create a shared context for this fetcher's lifetime. + ctx, cancel := context.WithCancel(context.Background()) + s.fetcherCtx = ctx + s.fetcherCancel = cancel + + // Start rate limiter and noise goroutines. + fetcher.Start(ctx) + + // Start periodic resolver health checks. + checker := client.NewResolverChecker(fetcher, timeout) + checker.SetLogFunc(func(msg string) { + s.addLog(msg) + }) + checker.Start(ctx) + s.fetcher = fetcher s.cache = cache return nil } -func (s *Server) startAutoRefresh() { - if s.stopRefresh != nil { - close(s.stopRefresh) +func (s *Server) refreshMetadataOnly() { + // Cancel any in-progress refresh and start a new cancellable one. + s.refreshMu.Lock() + if s.refreshCancel != nil { + s.refreshCancel() } - s.stopRefresh = make(chan struct{}) - stop := s.stopRefresh - go s.refresh() - - go func() { - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - for { - select { - case <-stop: - return - case <-ticker.C: - s.refresh() - } - } - }() -} - -func (s *Server) refresh() { s.mu.RLock() + basectx := s.fetcherCtx fetcher := s.fetcher cache := s.cache s.mu.RUnlock() - if fetcher == nil { + if fetcher == nil || basectx == nil { + s.refreshMu.Unlock() return } + // Child context: cancelled either by the next refresh call or by a config change. + ctx, cancel := context.WithCancel(basectx) + s.refreshCancel = cancel + s.refreshMu.Unlock() + defer func() { + cancel() + s.refreshMu.Lock() + s.refreshCancel = nil + s.refreshMu.Unlock() + }() + s.addLog("Fetching metadata...") - meta, err := fetcher.FetchMetadata() + meta, err := fetcher.FetchMetadata(ctx) if err != nil { + if ctx.Err() != nil { + s.addLog("Refresh cancelled") + return + } s.addLog(fmt.Sprintf("Error: %v", err)) return } @@ -359,31 +425,91 @@ func (s *Server) refresh() { } s.broadcast("event: update\ndata: \"channels\"\n\n") +} - for i, ch := range meta.Channels { - chNum := i + 1 - blockCount := int(ch.Blocks) - if blockCount <= 0 { - continue - } - - msgs, err := fetcher.FetchChannel(chNum, blockCount) - if err != nil { - s.addLog(fmt.Sprintf("Channel %s error: %v", ch.Name, err)) - continue - } - - s.mu.Lock() - s.messages[chNum] = msgs - s.mu.Unlock() - - if cache != nil { - _ = cache.PutMessages(chNum, msgs) - } - - s.addLog(fmt.Sprintf("Updated %s: %d messages", ch.Name, len(msgs))) +func (s *Server) refreshChannel(channelNum int) { + s.refreshMu.Lock() + if s.refreshCancel != nil { + s.refreshCancel() } + s.mu.RLock() + basectx := s.fetcherCtx + fetcher := s.fetcher + cache := s.cache + channels := s.channels + s.mu.RUnlock() + + if fetcher == nil || basectx == nil { + s.refreshMu.Unlock() + return + } + + ctx, cancel := context.WithCancel(basectx) + s.refreshCancel = cancel + s.refreshMu.Unlock() + defer func() { + cancel() + s.refreshMu.Lock() + s.refreshCancel = nil + s.refreshMu.Unlock() + }() + + meta, err := fetcher.FetchMetadata(ctx) + if err != nil { + if ctx.Err() != nil { + s.addLog("Refresh cancelled") + return + } + s.addLog(fmt.Sprintf("Error: %v", err)) + return + } + + s.mu.Lock() + s.channels = meta.Channels + s.mu.Unlock() + + if cache != nil { + _ = cache.PutMetadata(meta) + } + s.broadcast("event: update\ndata: \"channels\"\n\n") + + channels = meta.Channels + if channelNum < 1 || channelNum > len(channels) { + s.addLog(fmt.Sprintf("Warning: channel %d is not available", channelNum)) + return + } + + ch := channels[channelNum-1] + blockCount := int(ch.Blocks) + if blockCount <= 0 { + s.mu.Lock() + s.messages[channelNum] = nil + s.mu.Unlock() + s.addLog(fmt.Sprintf("Updated %s: 0 messages", ch.Name)) + s.broadcast("event: update\ndata: \"messages\"\n\n") + return + } + + msgs, err := fetcher.FetchChannel(ctx, channelNum, blockCount) + if err != nil { + if ctx.Err() != nil { + s.addLog("Refresh cancelled") + return + } + s.addLog(fmt.Sprintf("Channel %s error: %v", ch.Name, err)) + return + } + + s.mu.Lock() + s.messages[channelNum] = msgs + s.mu.Unlock() + + if cache != nil { + _ = cache.PutMessages(channelNum, msgs) + } + + s.addLog(fmt.Sprintf("Updated %s: %d messages", ch.Name, len(msgs))) s.broadcast("event: update\ndata: \"messages\"\n\n") }