mirror of
https://github.com/sartoopjj/thefeed.git
synced 2026-05-19 05:14:35 +03:00
feat: implement AES-256 block cipher for query encryption and decryption
This commit is contained in:
+25
-11
@@ -96,7 +96,7 @@ func TestE2E_FetchMetadataThroughDNS(t *testing.T) {
|
|||||||
t.Fatalf("create fetcher: %v", err)
|
t.Fatalf("create fetcher: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
meta, err := fetcher.FetchMetadata()
|
meta, err := fetcher.FetchMetadata(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("fetch metadata: %v", err)
|
t.Fatalf("fetch metadata: %v", err)
|
||||||
}
|
}
|
||||||
@@ -136,7 +136,7 @@ func TestE2E_FetchChannelMessages(t *testing.T) {
|
|||||||
t.Fatalf("create fetcher: %v", err)
|
t.Fatalf("create fetcher: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
meta, err := fetcher.FetchMetadata()
|
meta, err := fetcher.FetchMetadata(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("fetch metadata: %v", err)
|
t.Fatalf("fetch metadata: %v", err)
|
||||||
}
|
}
|
||||||
@@ -146,7 +146,7 @@ func TestE2E_FetchChannelMessages(t *testing.T) {
|
|||||||
t.Fatal("expected blocks > 0")
|
t.Fatal("expected blocks > 0")
|
||||||
}
|
}
|
||||||
|
|
||||||
fetchedMsgs, err := fetcher.FetchChannel(1, blockCount)
|
fetchedMsgs, err := fetcher.FetchChannel(context.Background(), 1, blockCount)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("fetch channel: %v", err)
|
t.Fatalf("fetch channel: %v", err)
|
||||||
}
|
}
|
||||||
@@ -180,14 +180,14 @@ func TestE2E_FetchWithDoubleLabel(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("create fetcher: %v", err)
|
t.Fatalf("create fetcher: %v", err)
|
||||||
}
|
}
|
||||||
fetcher.SetQueryMode(protocol.QueryDoubleLabel)
|
fetcher.SetQueryMode(protocol.QueryMultiLabel)
|
||||||
|
|
||||||
meta, err := fetcher.FetchMetadata()
|
meta, err := fetcher.FetchMetadata(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("fetch metadata: %v", err)
|
t.Fatalf("fetch metadata: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fetchedMsgs, err := fetcher.FetchChannel(1, int(meta.Channels[0].Blocks))
|
fetchedMsgs, err := fetcher.FetchChannel(context.Background(), 1, int(meta.Channels[0].Blocks))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("fetch channel: %v", err)
|
t.Fatalf("fetch channel: %v", err)
|
||||||
}
|
}
|
||||||
@@ -216,7 +216,7 @@ func TestE2E_WrongPassphrase(t *testing.T) {
|
|||||||
t.Fatalf("create fetcher: %v", err)
|
t.Fatalf("create fetcher: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = fetcher.FetchMetadata()
|
_, err = fetcher.FetchMetadata(context.Background())
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error with wrong passphrase, got nil")
|
t.Fatal("expected error with wrong passphrase, got nil")
|
||||||
}
|
}
|
||||||
@@ -243,12 +243,12 @@ func TestE2E_LargeMessages(t *testing.T) {
|
|||||||
t.Fatalf("create fetcher: %v", err)
|
t.Fatalf("create fetcher: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
meta, err := fetcher.FetchMetadata()
|
meta, err := fetcher.FetchMetadata(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("fetch metadata: %v", err)
|
t.Fatalf("fetch metadata: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fetchedMsgs, err := fetcher.FetchChannel(1, int(meta.Channels[0].Blocks))
|
fetchedMsgs, err := fetcher.FetchChannel(context.Background(), 1, int(meta.Channels[0].Blocks))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("fetch channel: %v", err)
|
t.Fatalf("fetch channel: %v", err)
|
||||||
}
|
}
|
||||||
@@ -586,8 +586,15 @@ func TestE2E_FullRoundTrip(t *testing.T) {
|
|||||||
t.Fatalf("config POST status=%d", resp.StatusCode)
|
t.Fatalf("config POST status=%d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for auto-refresh to fetch data
|
// Refresh channels via selected-channel API semantics.
|
||||||
time.Sleep(3 * time.Second)
|
respRefresh1, err := http.Post(base+"/api/refresh?channel=1", "application/json", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("POST /api/refresh?channel=1: %v", err)
|
||||||
|
}
|
||||||
|
respRefresh1.Body.Close()
|
||||||
|
// Give channel 1 refresh goroutine time to complete before refreshing channel 2,
|
||||||
|
// because starting a new refresh cancels the previous in-flight refresh.
|
||||||
|
time.Sleep(700 * time.Millisecond)
|
||||||
|
|
||||||
// Channels should be populated
|
// Channels should be populated
|
||||||
resp2, err := http.Get(base + "/api/channels")
|
resp2, err := http.Get(base + "/api/channels")
|
||||||
@@ -621,6 +628,13 @@ func TestE2E_FullRoundTrip(t *testing.T) {
|
|||||||
t.Errorf("msg[0].Text = %q, want %q", msgList[0].Text, "General message 1")
|
t.Errorf("msg[0].Text = %q, want %q", msgList[0].Text, "General message 1")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
respRefresh2, err := http.Post(base+"/api/refresh?channel=2", "application/json", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("POST /api/refresh?channel=2: %v", err)
|
||||||
|
}
|
||||||
|
respRefresh2.Body.Close()
|
||||||
|
time.Sleep(700 * time.Millisecond)
|
||||||
|
|
||||||
// Messages for channel 2
|
// Messages for channel 2
|
||||||
resp4, err := http.Get(base + "/api/messages/2")
|
resp4, err := http.Get(base + "/api/messages/2")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
+273
-92
@@ -1,6 +1,7 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -12,26 +13,36 @@ import (
|
|||||||
"github.com/sartoopjj/thefeed/internal/protocol"
|
"github.com/sartoopjj/thefeed/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LogFunc is a callback for logging DNS queries (for debug/TUI).
|
// LogFunc is a callback for log messages.
|
||||||
type LogFunc func(msg string)
|
type LogFunc func(msg string)
|
||||||
|
|
||||||
|
// noiseDomains are popular domains used to blend feed queries into normal-looking DNS traffic.
|
||||||
|
var noiseDomains = []string{
|
||||||
|
"www.google.com", "www.cloudflare.com", "one.one.one.one",
|
||||||
|
"www.youtube.com", "www.instagram.com", "www.amazon.com",
|
||||||
|
"www.microsoft.com", "www.apple.com", "www.github.com",
|
||||||
|
"www.wikipedia.org", "www.reddit.com", "www.twitter.com",
|
||||||
|
}
|
||||||
|
|
||||||
// Fetcher fetches feed blocks over DNS.
|
// Fetcher fetches feed blocks over DNS.
|
||||||
type Fetcher struct {
|
type Fetcher struct {
|
||||||
domain string
|
domain string
|
||||||
queryKey [protocol.KeySize]byte
|
queryKey [protocol.KeySize]byte
|
||||||
responseKey [protocol.KeySize]byte
|
responseKey [protocol.KeySize]byte
|
||||||
queryMode protocol.QueryEncoding
|
queryMode protocol.QueryEncoding
|
||||||
|
timeout time.Duration
|
||||||
|
|
||||||
mu sync.RWMutex
|
// Resolver pools — allResolvers is what the user configured;
|
||||||
resolvers []string
|
// activeResolvers is kept up-to-date by ResolverChecker (only healthy ones).
|
||||||
timeout time.Duration
|
mu sync.RWMutex
|
||||||
|
allResolvers []string
|
||||||
|
activeResolvers []string
|
||||||
|
|
||||||
// Rate limiting
|
// Rate limiting via token bucket; nil means unlimited.
|
||||||
rateMu sync.Mutex
|
rateQPS float64
|
||||||
queryDelay time.Duration
|
rateCh chan struct{}
|
||||||
lastQuery time.Time
|
|
||||||
|
|
||||||
// Debug logging
|
debug bool
|
||||||
logFunc LogFunc
|
logFunc LogFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -42,23 +53,28 @@ func NewFetcher(domain, passphrase string, resolvers []string) (*Fetcher, error)
|
|||||||
return nil, fmt.Errorf("derive keys: %w", err)
|
return nil, fmt.Errorf("derive keys: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r := make([]string, len(resolvers))
|
||||||
|
copy(r, resolvers)
|
||||||
|
|
||||||
return &Fetcher{
|
return &Fetcher{
|
||||||
domain: strings.TrimSuffix(domain, "."),
|
domain: strings.TrimSuffix(domain, "."),
|
||||||
queryKey: qk,
|
queryKey: qk,
|
||||||
responseKey: rk,
|
responseKey: rk,
|
||||||
queryMode: protocol.QuerySingleLabel,
|
queryMode: protocol.QuerySingleLabel,
|
||||||
resolvers: resolvers,
|
allResolvers: r,
|
||||||
timeout: 5 * time.Second,
|
activeResolvers: r,
|
||||||
|
timeout: 10 * time.Second,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetRateLimit sets the maximum queries per second (0 = unlimited).
|
// SetRateLimit sets the maximum queries per second (0 = unlimited). Must be called before Start.
|
||||||
func (f *Fetcher) SetRateLimit(qps float64) {
|
func (f *Fetcher) SetRateLimit(qps float64) {
|
||||||
if qps <= 0 {
|
f.rateQPS = qps
|
||||||
f.queryDelay = 0
|
}
|
||||||
return
|
|
||||||
}
|
// SetTimeout sets the per-query DNS timeout.
|
||||||
f.queryDelay = time.Duration(float64(time.Second) / qps)
|
func (f *Fetcher) SetTimeout(d time.Duration) {
|
||||||
|
f.timeout = d
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetLogFunc sets the debug log callback.
|
// SetLogFunc sets the debug log callback.
|
||||||
@@ -66,83 +82,225 @@ func (f *Fetcher) SetLogFunc(fn LogFunc) {
|
|||||||
f.logFunc = fn
|
f.logFunc = fn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetDebug enables or disables debug logging of generated query names.
|
||||||
|
func (f *Fetcher) SetDebug(debug bool) {
|
||||||
|
f.debug = debug
|
||||||
|
}
|
||||||
|
|
||||||
// SetQueryMode sets the DNS query encoding mode.
|
// SetQueryMode sets the DNS query encoding mode.
|
||||||
func (f *Fetcher) SetQueryMode(mode protocol.QueryEncoding) {
|
func (f *Fetcher) SetQueryMode(mode protocol.QueryEncoding) {
|
||||||
f.queryMode = mode
|
f.queryMode = mode
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetActiveResolvers updates the healthy resolver pool. Called by ResolverChecker.
|
||||||
|
// If the new list is empty, the current pool is unchanged (to avoid blackout during a bad check).
|
||||||
|
func (f *Fetcher) SetActiveResolvers(resolvers []string) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
if len(resolvers) > 0 {
|
||||||
|
f.activeResolvers = make([]string, len(resolvers))
|
||||||
|
copy(f.activeResolvers, resolvers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetResolvers replaces the full resolver list and resets the active pool.
|
||||||
|
func (f *Fetcher) SetResolvers(resolvers []string) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
f.allResolvers = make([]string, len(resolvers))
|
||||||
|
copy(f.allResolvers, resolvers)
|
||||||
|
f.activeResolvers = make([]string, len(resolvers))
|
||||||
|
copy(f.activeResolvers, resolvers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllResolvers returns all user-configured resolvers.
|
||||||
|
func (f *Fetcher) AllResolvers() []string {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
result := make([]string, len(f.allResolvers))
|
||||||
|
copy(result, f.allResolvers)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolvers returns the currently active (healthy) resolver list.
|
||||||
|
func (f *Fetcher) Resolvers() []string {
|
||||||
|
f.mu.RLock()
|
||||||
|
defer f.mu.RUnlock()
|
||||||
|
result := make([]string, len(f.activeResolvers))
|
||||||
|
copy(result, f.activeResolvers)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start launches background goroutines (rate limiter and noise generator).
|
||||||
|
// ctx controls their lifetime — cancel it to cleanly stop them.
|
||||||
|
// Call once per fetcher configuration; creating a new fetcher replaces the old one.
|
||||||
|
func (f *Fetcher) Start(ctx context.Context) {
|
||||||
|
if f.rateQPS > 0 {
|
||||||
|
f.rateCh = make(chan struct{}, 1)
|
||||||
|
go f.runRateLimiter(ctx)
|
||||||
|
go f.runNoise(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runRateLimiter issues one token into rateCh every 1/QPS seconds.
|
||||||
|
// The channel capacity is 1, so tokens do not accumulate (no burst).
|
||||||
|
func (f *Fetcher) runRateLimiter(ctx context.Context) {
|
||||||
|
interval := time.Duration(float64(time.Second) / f.rateQPS)
|
||||||
|
ticker := time.NewTicker(interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
select {
|
||||||
|
case f.rateCh <- struct{}{}:
|
||||||
|
default: // bucket full; discard extra token to prevent burst
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runNoise sends decoy A-record queries to popular domains at a randomised
|
||||||
|
// rate matching the configured QPS, to make feed traffic look like normal DNS usage.
|
||||||
|
func (f *Fetcher) runNoise(ctx context.Context) {
|
||||||
|
interval := time.Duration(float64(time.Second) / f.rateQPS)
|
||||||
|
for {
|
||||||
|
// Random delay: 1–3× the query interval.
|
||||||
|
jitter := time.Duration(rand.Int63n(int64(2*interval) + 1))
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(interval + jitter):
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvers := f.Resolvers()
|
||||||
|
if len(resolvers) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
resolver := resolvers[rand.Intn(len(resolvers))]
|
||||||
|
if !strings.Contains(resolver, ":") {
|
||||||
|
resolver += ":53"
|
||||||
|
}
|
||||||
|
target := noiseDomains[rand.Intn(len(noiseDomains))]
|
||||||
|
|
||||||
|
go func(r, d string) {
|
||||||
|
c := &dns.Client{Timeout: f.timeout}
|
||||||
|
m := new(dns.Msg)
|
||||||
|
m.SetQuestion(dns.Fqdn(d), dns.TypeA)
|
||||||
|
m.RecursionDesired = true
|
||||||
|
c.Exchange(m, r) //nolint:errcheck — fire-and-forget noise query
|
||||||
|
}(resolver, target)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (f *Fetcher) log(format string, args ...any) {
|
func (f *Fetcher) log(format string, args ...any) {
|
||||||
if f.logFunc != nil {
|
if f.logFunc != nil {
|
||||||
f.logFunc(fmt.Sprintf(format, args...))
|
f.logFunc(fmt.Sprintf(format, args...))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Fetcher) rateWait() {
|
// logProgress logs a progress bar: "prefix [====> ] 45%"
|
||||||
if f.queryDelay <= 0 {
|
func (f *Fetcher) logProgress(prefix string, current, total float64) {
|
||||||
|
if f.logFunc == nil || total <= 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f.rateMu.Lock()
|
|
||||||
defer f.rateMu.Unlock()
|
percent := int((current / total) * 100)
|
||||||
elapsed := time.Since(f.lastQuery)
|
barLen := 20
|
||||||
if elapsed < f.queryDelay {
|
filled := int((current / total) * float64(barLen))
|
||||||
time.Sleep(f.queryDelay - elapsed)
|
empty := barLen - filled
|
||||||
|
|
||||||
|
bar := "["
|
||||||
|
for i := 0; i < filled; i++ {
|
||||||
|
bar += "="
|
||||||
}
|
}
|
||||||
f.lastQuery = time.Now()
|
if filled < barLen {
|
||||||
|
bar += ">"
|
||||||
|
}
|
||||||
|
for i := 0; i < empty-1; i++ {
|
||||||
|
bar += " "
|
||||||
|
}
|
||||||
|
bar += "]"
|
||||||
|
|
||||||
|
f.logFunc(fmt.Sprintf("%s %s %d%%", prefix, bar, percent))
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetResolvers replaces the resolver list.
|
// rateWait blocks until a rate-limit token is available or ctx is cancelled.
|
||||||
func (f *Fetcher) SetResolvers(resolvers []string) {
|
// Returns nil when a token was acquired, ctx.Err() when cancelled.
|
||||||
f.mu.Lock()
|
func (f *Fetcher) rateWait(ctx context.Context) error {
|
||||||
defer f.mu.Unlock()
|
if f.rateCh == nil {
|
||||||
f.resolvers = resolvers
|
// Unlimited: just propagate any existing cancellation.
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-f.rateCh:
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolvers returns the current resolver list.
|
// FetchBlock fetches a single encrypted block from the given channel.
|
||||||
func (f *Fetcher) Resolvers() []string {
|
// It enqueues through the rate limiter and respects ctx cancellation.
|
||||||
f.mu.RLock()
|
// On transient failure it retries up to 2 additional times with a short back-off.
|
||||||
defer f.mu.RUnlock()
|
func (f *Fetcher) FetchBlock(ctx context.Context, channel, block uint16) ([]byte, error) {
|
||||||
result := make([]string, len(f.resolvers))
|
const maxAttempts = 3
|
||||||
copy(result, f.resolvers)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// FetchBlock fetches a single block from a channel.
|
|
||||||
func (f *Fetcher) FetchBlock(channel, block uint16) ([]byte, error) {
|
|
||||||
f.rateWait()
|
|
||||||
|
|
||||||
qname, err := protocol.EncodeQuery(f.queryKey, channel, block, f.domain, f.queryMode)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("encode query: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
f.log("Q ch=%d blk=%d → %s", channel, block, qname)
|
|
||||||
|
|
||||||
resolvers := f.Resolvers()
|
|
||||||
if len(resolvers) == 0 {
|
|
||||||
return nil, fmt.Errorf("no resolvers configured")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shuffle resolvers to distribute load
|
|
||||||
shuffled := make([]string, len(resolvers))
|
|
||||||
copy(shuffled, resolvers)
|
|
||||||
rand.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] })
|
|
||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for _, resolver := range shuffled {
|
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||||
data, err := f.queryResolver(resolver, qname)
|
if attempt > 0 {
|
||||||
if err != nil {
|
// Brief back-off before retry; bail immediately if ctx is done.
|
||||||
lastErr = err
|
select {
|
||||||
continue
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(time.Duration(attempt) * 500 * time.Millisecond):
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("all resolvers failed, last error: %w", lastErr)
|
if err := f.rateWait(ctx); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
qname, err := protocol.EncodeQuery(f.queryKey, channel, block, f.domain, f.queryMode)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("encode query: %w", err)
|
||||||
|
}
|
||||||
|
if f.debug {
|
||||||
|
f.log("[debug] query ch=%d blk=%d attempt=%d qname=%s", channel, block, attempt+1, qname)
|
||||||
|
}
|
||||||
|
|
||||||
|
resolvers := f.Resolvers()
|
||||||
|
if len(resolvers) == 0 {
|
||||||
|
return nil, fmt.Errorf("no active resolvers")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shuffle to spread load across resolvers.
|
||||||
|
shuffled := make([]string, len(resolvers))
|
||||||
|
copy(shuffled, resolvers)
|
||||||
|
rand.Shuffle(len(shuffled), func(i, j int) { shuffled[i], shuffled[j] = shuffled[j], shuffled[i] })
|
||||||
|
|
||||||
|
for _, resolver := range shuffled {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
data, err := f.queryResolver(ctx, resolver, qname)
|
||||||
|
if err != nil {
|
||||||
|
lastErr = err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
lastErr = fmt.Errorf("all resolvers failed: %w", lastErr)
|
||||||
|
if attempt+1 < maxAttempts {
|
||||||
|
f.log("block ch=%d blk=%d attempt %d/%d failed, retrying: %v", channel, block, attempt+1, maxAttempts, lastErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, lastErr
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchMetadata fetches and parses metadata (channel 0).
|
// FetchMetadata fetches and parses the metadata block (channel 0).
|
||||||
func (f *Fetcher) FetchMetadata() (*protocol.Metadata, error) {
|
func (f *Fetcher) FetchMetadata(ctx context.Context) (*protocol.Metadata, error) {
|
||||||
data, err := f.FetchBlock(protocol.MetadataChannel, 0)
|
data, err := f.FetchBlock(ctx, protocol.MetadataChannel, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("fetch metadata block 0: %w", err)
|
return nil, fmt.Errorf("fetch metadata block 0: %w", err)
|
||||||
}
|
}
|
||||||
@@ -152,12 +310,15 @@ func (f *Fetcher) FetchMetadata() (*protocol.Metadata, error) {
|
|||||||
return meta, nil
|
return meta, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Metadata might span multiple blocks
|
// Metadata may span multiple blocks.
|
||||||
allData := make([]byte, len(data))
|
allData := make([]byte, len(data))
|
||||||
copy(allData, data)
|
copy(allData, data)
|
||||||
|
|
||||||
for blk := uint16(1); blk < 10; blk++ {
|
for blk := uint16(1); blk < 10; blk++ {
|
||||||
block, fetchErr := f.FetchBlock(protocol.MetadataChannel, blk)
|
if ctx.Err() != nil {
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
block, fetchErr := f.FetchBlock(ctx, protocol.MetadataChannel, blk)
|
||||||
if fetchErr != nil {
|
if fetchErr != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -171,31 +332,41 @@ func (f *Fetcher) FetchMetadata() (*protocol.Metadata, error) {
|
|||||||
return nil, fmt.Errorf("could not parse metadata: %w", err)
|
return nil, fmt.Errorf("could not parse metadata: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchChannel fetches all blocks for a channel and parses messages.
|
// FetchChannel fetches all blocks for a channel and returns the parsed messages.
|
||||||
func (f *Fetcher) FetchChannel(channelNum int, blockCount int) ([]protocol.Message, error) {
|
// Cancelling ctx immediately aborts any queued or in-flight block fetches.
|
||||||
|
// Each block is retried individually via FetchBlock before the channel fetch fails.
|
||||||
|
func (f *Fetcher) FetchChannel(ctx context.Context, channelNum int, blockCount int) ([]protocol.Message, error) {
|
||||||
if blockCount <= 0 {
|
if blockCount <= 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type result struct {
|
type blockResult struct {
|
||||||
idx int
|
idx int
|
||||||
data []byte
|
data []byte
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
results := make(chan result, blockCount)
|
results := make(chan blockResult, blockCount)
|
||||||
// Limit concurrency to 3 to reduce DNS burst traffic
|
// Cap concurrent DNS queries at 5; the token-bucket rate limiter provides
|
||||||
sem := make(chan struct{}, 3)
|
// the actual throughput control regardless of this concurrency cap.
|
||||||
|
sem := make(chan struct{}, 5)
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
for i := 0; i < blockCount; i++ {
|
for i := 0; i < blockCount; i++ {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(idx int) {
|
go func(idx int) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
sem <- struct{}{}
|
// Acquire semaphore or bail on cancellation.
|
||||||
|
select {
|
||||||
|
case sem <- struct{}{}:
|
||||||
|
case <-ctx.Done():
|
||||||
|
results <- blockResult{idx: idx, err: ctx.Err()}
|
||||||
|
return
|
||||||
|
}
|
||||||
defer func() { <-sem }()
|
defer func() { <-sem }()
|
||||||
data, err := f.FetchBlock(uint16(channelNum), uint16(idx))
|
|
||||||
results <- result{idx: idx, data: data, err: err}
|
data, err := f.FetchBlock(ctx, uint16(channelNum), uint16(idx))
|
||||||
|
results <- blockResult{idx: idx, data: data, err: err}
|
||||||
}(i)
|
}(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,24 +376,33 @@ func (f *Fetcher) FetchChannel(channelNum int, blockCount int) ([]protocol.Messa
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
ordered := make([][]byte, blockCount)
|
ordered := make([][]byte, blockCount)
|
||||||
|
completed := 0
|
||||||
for r := range results {
|
for r := range results {
|
||||||
if r.err != nil {
|
if r.err != nil {
|
||||||
return nil, fmt.Errorf("fetch block %d: %w", r.idx, r.err)
|
if r.err == ctx.Err() {
|
||||||
|
// Context cancelled — abort immediately.
|
||||||
|
return nil, r.err
|
||||||
|
}
|
||||||
|
// FetchBlock already retried internally; log and treat as fatal for this channel.
|
||||||
|
f.log("Channel %d block %d permanently failed: %v", channelNum, r.idx, r.err)
|
||||||
|
return nil, fmt.Errorf("channel %d block %d: %w", channelNum, r.idx, r.err)
|
||||||
}
|
}
|
||||||
ordered[r.idx] = r.data
|
ordered[r.idx] = r.data
|
||||||
|
completed++
|
||||||
|
f.logProgress(fmt.Sprintf("Channel %d", channelNum), float64(completed), float64(blockCount))
|
||||||
}
|
}
|
||||||
|
|
||||||
var allData []byte
|
var allData []byte
|
||||||
for _, block := range ordered {
|
for _, b := range ordered {
|
||||||
allData = append(allData, block...)
|
allData = append(allData, b...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return protocol.ParseMessages(allData)
|
return protocol.ParseMessages(allData)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Fetcher) queryResolver(resolver, qname string) ([]byte, error) {
|
func (f *Fetcher) queryResolver(ctx context.Context, resolver, qname string) ([]byte, error) {
|
||||||
if !strings.Contains(resolver, ":") {
|
if !strings.Contains(resolver, ":") {
|
||||||
resolver = resolver + ":53"
|
resolver += ":53"
|
||||||
}
|
}
|
||||||
|
|
||||||
c := new(dns.Client)
|
c := new(dns.Client)
|
||||||
@@ -231,8 +411,9 @@ func (f *Fetcher) queryResolver(resolver, qname string) ([]byte, error) {
|
|||||||
m := new(dns.Msg)
|
m := new(dns.Msg)
|
||||||
m.SetQuestion(dns.Fqdn(qname), dns.TypeTXT)
|
m.SetQuestion(dns.Fqdn(qname), dns.TypeTXT)
|
||||||
m.RecursionDesired = true
|
m.RecursionDesired = true
|
||||||
|
m.SetEdns0(4096, false) // advertise 4 KiB UDP buffer to avoid response truncation
|
||||||
|
|
||||||
resp, _, err := c.Exchange(m, resolver)
|
resp, _, err := c.ExchangeContext(ctx, m, resolver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("dns exchange with %s: %w", resolver, err)
|
return nil, fmt.Errorf("dns exchange with %s: %w", resolver, err)
|
||||||
}
|
}
|
||||||
|
|||||||
+85
-137
@@ -1,181 +1,129 @@
|
|||||||
package client
|
package client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
|
"github.com/sartoopjj/thefeed/internal/protocol"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ResolverScanner scans CIDR ranges to find working DNS resolvers.
|
// ResolverChecker periodically probes the fetcher's configured resolvers and
|
||||||
type ResolverScanner struct {
|
// updates the active (healthy) resolver pool. It replaces the old file/CIDR
|
||||||
fetcher *Fetcher
|
// scanner — no file I/O; just a plain DNS probe on channel 0.
|
||||||
concurrency int
|
type ResolverChecker struct {
|
||||||
timeout time.Duration
|
fetcher *Fetcher
|
||||||
|
timeout time.Duration
|
||||||
|
logFunc LogFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResolverScanner creates a resolver scanner.
|
// NewResolverChecker creates a health checker for the resolvers in fetcher.
|
||||||
func NewResolverScanner(fetcher *Fetcher, concurrency int) *ResolverScanner {
|
// timeout is the per-probe deadline; 0 uses a 5-second default.
|
||||||
if concurrency <= 0 {
|
func NewResolverChecker(fetcher *Fetcher, timeout time.Duration) *ResolverChecker {
|
||||||
concurrency = 50
|
if timeout <= 0 {
|
||||||
|
timeout = 5 * time.Second
|
||||||
}
|
}
|
||||||
return &ResolverScanner{
|
return &ResolverChecker{
|
||||||
fetcher: fetcher,
|
fetcher: fetcher,
|
||||||
concurrency: concurrency,
|
timeout: timeout,
|
||||||
timeout: 3 * time.Second,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScanCIDR scans a CIDR range for working DNS resolvers.
|
// SetLogFunc sets the callback used to emit health-check results to the log panel.
|
||||||
func (rs *ResolverScanner) ScanCIDR(cidr string, onFound func(ip string)) error {
|
func (rc *ResolverChecker) SetLogFunc(fn LogFunc) {
|
||||||
_, ipNet, err := net.ParseCIDR(cidr)
|
rc.logFunc = fn
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse CIDR %q: %w", cidr, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ips := expandCIDR(ipNet)
|
|
||||||
return rs.scanIPs(ips, onFound)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ScanFile scans resolver IPs from a file (one per line, supports CIDR notation).
|
// Start begins the periodic health-check loop in the background.
|
||||||
func (rs *ResolverScanner) ScanFile(path string, onFound func(ip string)) error {
|
// An initial check runs immediately; subsequent checks happen every 10 minutes.
|
||||||
f, err := os.Open(path)
|
// ctx controls the lifetime — cancel it to stop the checker.
|
||||||
if err != nil {
|
func (rc *ResolverChecker) Start(ctx context.Context) {
|
||||||
return err
|
go func() {
|
||||||
}
|
rc.runCheck()
|
||||||
defer f.Close()
|
ticker := time.NewTicker(10 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
var ips []string
|
for {
|
||||||
scanner := bufio.NewScanner(f)
|
select {
|
||||||
for scanner.Scan() {
|
case <-ctx.Done():
|
||||||
line := strings.TrimSpace(scanner.Text())
|
return
|
||||||
if line == "" || strings.HasPrefix(line, "#") {
|
case <-ticker.C:
|
||||||
continue
|
rc.runCheck()
|
||||||
}
|
|
||||||
if strings.Contains(line, "/") {
|
|
||||||
_, ipNet, err := net.ParseCIDR(line)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("[resolver] skip invalid CIDR: %s", line)
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
ips = append(ips, expandCIDR(ipNet)...)
|
|
||||||
} else {
|
|
||||||
ips = append(ips, line)
|
|
||||||
}
|
}
|
||||||
}
|
}()
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return rs.scanIPs(ips, onFound)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckResolver tests if a single resolver works by querying metadata.
|
func (rc *ResolverChecker) runCheck() {
|
||||||
func (rs *ResolverScanner) CheckResolver(ip string) bool {
|
resolvers := rc.fetcher.AllResolvers()
|
||||||
if !strings.Contains(ip, ":") {
|
if len(resolvers) == 0 {
|
||||||
ip = ip + ":53"
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new fetcher with only this resolver to avoid copying the lock.
|
rc.log("Checking %d resolver(s)...", len(resolvers))
|
||||||
tmpFetcher := &Fetcher{
|
|
||||||
domain: rs.fetcher.domain,
|
|
||||||
queryKey: rs.fetcher.queryKey,
|
|
||||||
responseKey: rs.fetcher.responseKey,
|
|
||||||
resolvers: []string{ip},
|
|
||||||
timeout: rs.timeout,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := tmpFetcher.FetchBlock(0, 0)
|
var healthy []string
|
||||||
return err == nil
|
var mu sync.Mutex
|
||||||
}
|
|
||||||
|
|
||||||
func (rs *ResolverScanner) scanIPs(ips []string, onFound func(ip string)) error {
|
|
||||||
if len(ips) == 0 {
|
|
||||||
return fmt.Errorf("no IPs to scan")
|
|
||||||
}
|
|
||||||
|
|
||||||
var found atomic.Int32
|
|
||||||
sem := make(chan struct{}, rs.concurrency)
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
sem := make(chan struct{}, 10) // probe up to 10 resolvers concurrently
|
||||||
|
|
||||||
for _, ip := range ips {
|
for _, r := range resolvers {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(ip string) {
|
go func(r string) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
sem <- struct{}{}
|
sem <- struct{}{}
|
||||||
defer func() { <-sem }()
|
defer func() { <-sem }()
|
||||||
|
|
||||||
if rs.CheckResolver(ip) {
|
if rc.checkOne(r) {
|
||||||
found.Add(1)
|
mu.Lock()
|
||||||
if onFound != nil {
|
healthy = append(healthy, r)
|
||||||
onFound(ip)
|
mu.Unlock()
|
||||||
}
|
rc.log("Resolver OK: %s", r)
|
||||||
|
} else {
|
||||||
|
rc.log("Resolver failed: %s", r)
|
||||||
}
|
}
|
||||||
}(ip)
|
}(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
if found.Load() == 0 {
|
rc.fetcher.SetActiveResolvers(healthy)
|
||||||
return fmt.Errorf("no working resolvers found among %d IPs", len(ips))
|
rc.log("Resolver check done: %d/%d healthy", len(healthy), len(resolvers))
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadResolversFile loads resolver IPs from a file (one per line).
|
// checkOne probes a single resolver by sending a metadata channel query
|
||||||
func LoadResolversFile(path string) ([]string, error) {
|
// (channel 0, block 0). A successful DNS response (any rcode that isn't a
|
||||||
f, err := os.Open(path)
|
// network/timeout error) means the resolver is reachable and understands the domain.
|
||||||
|
func (rc *ResolverChecker) checkOne(resolver string) bool {
|
||||||
|
if !strings.Contains(resolver, ":") {
|
||||||
|
resolver += ":53"
|
||||||
|
}
|
||||||
|
|
||||||
|
qname, err := protocol.EncodeQuery(
|
||||||
|
rc.fetcher.queryKey,
|
||||||
|
protocol.MetadataChannel, 0,
|
||||||
|
rc.fetcher.domain,
|
||||||
|
rc.fetcher.queryMode,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return false
|
||||||
}
|
}
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
var resolvers []string
|
c := &dns.Client{Timeout: rc.timeout}
|
||||||
scanner := bufio.NewScanner(f)
|
m := new(dns.Msg)
|
||||||
for scanner.Scan() {
|
m.SetQuestion(dns.Fqdn(qname), dns.TypeTXT)
|
||||||
line := strings.TrimSpace(scanner.Text())
|
m.RecursionDesired = true
|
||||||
if line == "" || strings.HasPrefix(line, "#") {
|
|
||||||
continue
|
resp, _, err := c.Exchange(m, resolver)
|
||||||
}
|
// We consider the resolver healthy if we get any DNS response back
|
||||||
resolvers = append(resolvers, line)
|
// (even NXDOMAIN means the resolver forwarded the query to our server).
|
||||||
}
|
return err == nil && resp != nil
|
||||||
return resolvers, scanner.Err()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func expandCIDR(ipNet *net.IPNet) []string {
|
func (rc *ResolverChecker) log(format string, args ...any) {
|
||||||
var ips []string
|
if rc.logFunc != nil {
|
||||||
ip := ipNet.IP.Mask(ipNet.Mask)
|
rc.logFunc(fmt.Sprintf(format, args...))
|
||||||
|
|
||||||
for ip := cloneIP(ip); ipNet.Contains(ip); incIP(ip) {
|
|
||||||
// Skip network and broadcast addresses for /24 and smaller
|
|
||||||
ones, bits := ipNet.Mask.Size()
|
|
||||||
if bits-ones <= 8 {
|
|
||||||
last := ip[len(ip)-1]
|
|
||||||
if last == 0 || last == 255 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ips = append(ips, ip.String())
|
|
||||||
}
|
|
||||||
return ips
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloneIP(ip net.IP) net.IP {
|
|
||||||
dup := make(net.IP, len(ip))
|
|
||||||
copy(dup, ip)
|
|
||||||
return dup
|
|
||||||
}
|
|
||||||
|
|
||||||
func incIP(ip net.IP) {
|
|
||||||
for j := len(ip) - 1; j >= 0; j-- {
|
|
||||||
ip[j]++
|
|
||||||
if ip[j] > 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -67,3 +67,42 @@ func Decrypt(key [KeySize]byte, ciphertext []byte) ([]byte, error) {
|
|||||||
nonce := ciphertext[:gcm.NonceSize()]
|
nonce := ciphertext[:gcm.NonceSize()]
|
||||||
return gcm.Open(nil, nonce, ciphertext[gcm.NonceSize():], nil)
|
return gcm.Open(nil, nonce, ciphertext[gcm.NonceSize():], nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// encryptQueryBlock encrypts an 8-byte query payload using a direct AES-256 block cipher.
|
||||||
|
// The payload is expanded to one AES block (16 bytes) with 8 trailing zero bytes before
|
||||||
|
// encryption. No nonce or auth tag needed: the 4 random bytes in the payload guarantee
|
||||||
|
// unique ciphertext per query. Result is always 16 bytes.
|
||||||
|
func encryptQueryBlock(key [KeySize]byte, payload []byte) ([]byte, error) {
|
||||||
|
if len(payload) != QueryPayloadSize {
|
||||||
|
return nil, fmt.Errorf("encryptQueryBlock: payload must be %d bytes, got %d", QueryPayloadSize, len(payload))
|
||||||
|
}
|
||||||
|
block, err := aes.NewCipher(key[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var buf [aes.BlockSize]byte
|
||||||
|
copy(buf[:QueryPayloadSize], payload) // bytes 8-15 stay zero
|
||||||
|
block.Encrypt(buf[:], buf[:])
|
||||||
|
return buf[:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decryptQueryBlock decrypts a query ciphertext produced by encryptQueryBlock.
|
||||||
|
// Accepts ciphertext with optional random suffix bytes (≥ BlockSize); only the
|
||||||
|
// first BlockSize bytes are used. Verifies the last 8 bytes of plaintext are zero.
|
||||||
|
func decryptQueryBlock(key [KeySize]byte, ciphertext []byte) ([]byte, error) {
|
||||||
|
if len(ciphertext) < aes.BlockSize {
|
||||||
|
return nil, fmt.Errorf("decryptQueryBlock: need at least %d bytes, got %d", aes.BlockSize, len(ciphertext))
|
||||||
|
}
|
||||||
|
block, err := aes.NewCipher(key[:])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var buf [aes.BlockSize]byte
|
||||||
|
block.Decrypt(buf[:], ciphertext[:aes.BlockSize]) // ignore suffix
|
||||||
|
for i := QueryPayloadSize; i < aes.BlockSize; i++ {
|
||||||
|
if buf[i] != 0 {
|
||||||
|
return nil, fmt.Errorf("decryptQueryBlock: integrity check failed (wrong key?)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return buf[:QueryPayloadSize], nil
|
||||||
|
}
|
||||||
|
|||||||
+137
-29
@@ -7,26 +7,48 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxDNSLabelLen = 63
|
||||||
|
maxDNSNameLen = 253 // without trailing dot
|
||||||
|
)
|
||||||
|
|
||||||
// QueryEncoding controls how DNS query subdomains are encoded.
|
// QueryEncoding controls how DNS query subdomains are encoded.
|
||||||
type QueryEncoding int
|
type QueryEncoding int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// QuerySingleLabel uses base32 in a single DNS label (default, stealthier).
|
// QuerySingleLabel uses base32 in a single DNS label (default, stealthier).
|
||||||
QuerySingleLabel QueryEncoding = iota
|
QuerySingleLabel QueryEncoding = iota
|
||||||
// QueryDoubleLabel uses hex split across two DNS labels.
|
// QueryMultiLabel uses hex split across multiple DNS labels.
|
||||||
QueryDoubleLabel
|
QueryMultiLabel
|
||||||
|
// QueryPlainLabel encodes channel and block as plain decimal text (no query encryption).
|
||||||
|
// Responses are always encrypted regardless of this setting.
|
||||||
|
QueryPlainLabel
|
||||||
)
|
)
|
||||||
|
|
||||||
var b32 = base32.StdEncoding.WithPadding(base32.NoPadding)
|
var b32 = base32.StdEncoding.WithPadding(base32.NoPadding)
|
||||||
|
|
||||||
// EncodeQuery creates an encrypted DNS query subdomain.
|
// EncodeQuery creates a DNS query subdomain for the given channel and block.
|
||||||
// Single-label (default): [base32_encrypted].domain
|
// Single-label (default): [base32_encrypted].domain
|
||||||
// Double-label: [hex_part1].[hex_part2].domain
|
// Multi-label: [hex_part1].[hex_part2].domain
|
||||||
// Payload: 4 random + 2 channel + 2 block = 8 bytes, encrypted with AES-GCM.
|
// Plain-label: c<channel>b<block>.domain (no query encryption)
|
||||||
|
// Responses are always encrypted regardless of mode.
|
||||||
func EncodeQuery(queryKey [KeySize]byte, channel, block uint16, domain string, mode QueryEncoding) (string, error) {
|
func EncodeQuery(queryKey [KeySize]byte, channel, block uint16, domain string, mode QueryEncoding) (string, error) {
|
||||||
|
domain = strings.TrimSuffix(domain, ".")
|
||||||
|
if domain == "" {
|
||||||
|
return "", fmt.Errorf("empty domain")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Plain text mode: no encryption, just human-readable label.
|
||||||
|
if mode == QueryPlainLabel {
|
||||||
|
label := fmt.Sprintf("c%db%d", channel, block)
|
||||||
|
return joinQName([]string{label}, domain)
|
||||||
|
}
|
||||||
|
|
||||||
payload := make([]byte, QueryPayloadSize)
|
payload := make([]byte, QueryPayloadSize)
|
||||||
|
|
||||||
if _, err := rand.Read(payload[:QueryPaddingSize]); err != nil {
|
if _, err := rand.Read(payload[:QueryPaddingSize]); err != nil {
|
||||||
@@ -36,24 +58,88 @@ func EncodeQuery(queryKey [KeySize]byte, channel, block uint16, domain string, m
|
|||||||
binary.BigEndian.PutUint16(payload[QueryPaddingSize:], channel)
|
binary.BigEndian.PutUint16(payload[QueryPaddingSize:], channel)
|
||||||
binary.BigEndian.PutUint16(payload[QueryPaddingSize+QueryChannelSize:], block)
|
binary.BigEndian.PutUint16(payload[QueryPaddingSize+QueryChannelSize:], block)
|
||||||
|
|
||||||
encrypted, err := Encrypt(queryKey, payload)
|
encrypted, err := encryptQueryBlock(queryKey, payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("encrypt query: %w", err)
|
return "", fmt.Errorf("encrypt query: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Append 0–4 random suffix bytes so query length varies per request.
|
||||||
|
// The decoder strips these by only using the first aes.BlockSize bytes.
|
||||||
|
suffixLen, _ := rand.Int(rand.Reader, big.NewInt(5)) // [0,4]
|
||||||
|
suffix := make([]byte, int(suffixLen.Int64()))
|
||||||
|
rand.Read(suffix) //nolint:errcheck — non-critical randomness
|
||||||
|
ciphertext := append(encrypted, suffix...)
|
||||||
|
|
||||||
switch mode {
|
switch mode {
|
||||||
case QueryDoubleLabel:
|
case QueryMultiLabel:
|
||||||
h := hex.EncodeToString(encrypted)
|
h := hex.EncodeToString(ciphertext)
|
||||||
mid := len(h) / 2
|
labels := splitMultiLabel(h)
|
||||||
return fmt.Sprintf("%s.%s.%s", h[:mid], h[mid:], domain), nil
|
return joinQName(labels, domain)
|
||||||
default:
|
default:
|
||||||
encoded := strings.ToLower(b32.EncodeToString(encrypted))
|
encoded := strings.ToLower(b32.EncodeToString(ciphertext))
|
||||||
return fmt.Sprintf("%s.%s", encoded, domain), nil
|
return joinQName([]string{encoded}, domain)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func splitLabel(s string, size int) []string {
|
||||||
|
if size <= 0 {
|
||||||
|
size = maxDNSLabelLen
|
||||||
|
}
|
||||||
|
if size > maxDNSLabelLen {
|
||||||
|
size = maxDNSLabelLen
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := make([]string, 0, (len(s)+size-1)/size)
|
||||||
|
for len(s) > size {
|
||||||
|
parts = append(parts, s[:size])
|
||||||
|
s = s[size:]
|
||||||
|
}
|
||||||
|
if len(s) > 0 {
|
||||||
|
parts = append(parts, s)
|
||||||
|
}
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitMultiLabel splits a hex string into two labels of randomised, unequal length.
|
||||||
|
// The first label is between 12 and (len-4) chars so the second is at least 4 chars.
|
||||||
|
// This makes query labels look less uniform across requests.
|
||||||
|
func splitMultiLabel(h string) []string {
|
||||||
|
if len(h) <= 8 {
|
||||||
|
return []string{h}
|
||||||
|
}
|
||||||
|
// first label: random length in [minFirst, len-4]
|
||||||
|
minFirst := 8
|
||||||
|
maxFirst := len(h) - 4
|
||||||
|
if maxFirst <= minFirst {
|
||||||
|
maxFirst = minFirst + 1
|
||||||
|
}
|
||||||
|
// crypto/rand for the split point; fall back to midpoint on error
|
||||||
|
split := (len(h) + 1) / 2 // default: slightly off-centre
|
||||||
|
if n, err := rand.Int(rand.Reader, big.NewInt(int64(maxFirst-minFirst+1))); err == nil {
|
||||||
|
split = minFirst + int(n.Int64())
|
||||||
|
}
|
||||||
|
return []string{h[:split], h[split:]}
|
||||||
|
}
|
||||||
|
|
||||||
|
func joinQName(labels []string, domain string) (string, error) {
|
||||||
|
for _, l := range labels {
|
||||||
|
if len(l) == 0 {
|
||||||
|
return "", fmt.Errorf("empty label")
|
||||||
|
}
|
||||||
|
if len(l) > maxDNSLabelLen {
|
||||||
|
return "", fmt.Errorf("label too long: %d", len(l))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
qname := strings.Join(append(labels, domain), ".")
|
||||||
|
if len(qname) > maxDNSNameLen {
|
||||||
|
return "", fmt.Errorf("query name too long: %d", len(qname))
|
||||||
|
}
|
||||||
|
return qname, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DecodeQuery parses and decrypts a DNS query subdomain.
|
// DecodeQuery parses and decrypts a DNS query subdomain.
|
||||||
// Auto-detects single-label (base32) or double-label (hex) encoding.
|
// Auto-detects plain-text (c<N>b<M>), single-label base32, or multi-label hex encoding.
|
||||||
func DecodeQuery(queryKey [KeySize]byte, qname, domain string) (channel, block uint16, err error) {
|
func DecodeQuery(queryKey [KeySize]byte, qname, domain string) (channel, block uint16, err error) {
|
||||||
qname = strings.TrimSuffix(qname, ".")
|
qname = strings.TrimSuffix(qname, ".")
|
||||||
domain = strings.TrimSuffix(domain, ".")
|
domain = strings.TrimSuffix(domain, ".")
|
||||||
@@ -65,32 +151,54 @@ func DecodeQuery(queryKey [KeySize]byte, qname, domain string) (channel, block u
|
|||||||
|
|
||||||
encoded := qname[:len(qname)-len(suffix)]
|
encoded := qname[:len(qname)-len(suffix)]
|
||||||
|
|
||||||
// Try base32 first (single label, no dots, or dots stripped)
|
// Try plain-label first: c<channel>b<block> (short, no dots, all decimal)
|
||||||
b32str := strings.ReplaceAll(encoded, ".", "")
|
if ch, blk, ok := parsePlainLabel(encoded); ok {
|
||||||
ciphertext, err := b32.DecodeString(strings.ToUpper(b32str))
|
return ch, blk, nil
|
||||||
if err == nil {
|
|
||||||
return decryptQuery(queryKey, ciphertext)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fall back to hex (double-label)
|
// Try base32 (single-label: no dots or dots stripped)
|
||||||
|
b32str := strings.ReplaceAll(encoded, ".", "")
|
||||||
|
if ct, e := b32.DecodeString(strings.ToUpper(b32str)); e == nil {
|
||||||
|
return parseQueryCiphertext(queryKey, ct)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fall back to hex (multi-label: dots stripped)
|
||||||
hexStr := strings.ReplaceAll(encoded, ".", "")
|
hexStr := strings.ReplaceAll(encoded, ".", "")
|
||||||
ciphertext, err = hex.DecodeString(hexStr)
|
ct, e := hex.DecodeString(hexStr)
|
||||||
if err != nil {
|
if e != nil {
|
||||||
return 0, 0, fmt.Errorf("decode query: invalid encoding")
|
return 0, 0, fmt.Errorf("decode query: invalid encoding")
|
||||||
}
|
}
|
||||||
return decryptQuery(queryKey, ciphertext)
|
return parseQueryCiphertext(queryKey, ct)
|
||||||
}
|
}
|
||||||
|
|
||||||
func decryptQuery(queryKey [KeySize]byte, ciphertext []byte) (channel, block uint16, err error) {
|
// parsePlainLabel parses the plain-text query format "c<channel>b<block>".
|
||||||
plaintext, err := Decrypt(queryKey, ciphertext)
|
// Returns ok=false if the string does not match this pattern.
|
||||||
|
func parsePlainLabel(s string) (channel, block uint16, ok bool) {
|
||||||
|
if len(s) < 3 || s[0] != 'c' {
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
bi := strings.IndexByte(s[1:], 'b')
|
||||||
|
if bi < 0 {
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
bi++ // adjust for the slice offset
|
||||||
|
chStr, bStr := s[1:bi], s[bi+1:]
|
||||||
|
if len(chStr) == 0 || len(bStr) == 0 {
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
ch, err1 := strconv.ParseUint(chStr, 10, 16)
|
||||||
|
blk, err2 := strconv.ParseUint(bStr, 10, 16)
|
||||||
|
if err1 != nil || err2 != nil {
|
||||||
|
return 0, 0, false
|
||||||
|
}
|
||||||
|
return uint16(ch), uint16(blk), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseQueryCiphertext(queryKey [KeySize]byte, ciphertext []byte) (channel, block uint16, err error) {
|
||||||
|
plaintext, err := decryptQueryBlock(queryKey, ciphertext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, fmt.Errorf("decrypt: %w", err)
|
return 0, 0, fmt.Errorf("decrypt: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(plaintext) != QueryPayloadSize {
|
|
||||||
return 0, 0, fmt.Errorf("invalid payload size: %d", len(plaintext))
|
|
||||||
}
|
|
||||||
|
|
||||||
channel = binary.BigEndian.Uint16(plaintext[QueryPaddingSize:])
|
channel = binary.BigEndian.Uint16(plaintext[QueryPaddingSize:])
|
||||||
block = binary.BigEndian.Uint16(plaintext[QueryPaddingSize+QueryChannelSize:])
|
block = binary.BigEndian.Uint16(plaintext[QueryPaddingSize+QueryChannelSize:])
|
||||||
return channel, block, nil
|
return channel, block, nil
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package protocol
|
package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@@ -43,21 +44,26 @@ func TestEncodeDecodeQuerySingleLabel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEncodeDecodeQueryDoubleLabel(t *testing.T) {
|
func TestEncodeDecodeQueryMultiLabel(t *testing.T) {
|
||||||
qk, _, err := DeriveKeys("test-key")
|
qk, _, err := DeriveKeys("test-key")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("DeriveKeys: %v", err)
|
t.Fatalf("DeriveKeys: %v", err)
|
||||||
}
|
}
|
||||||
domain := "t.example.com"
|
domain := "t.example.com"
|
||||||
qname, err := EncodeQuery(qk, 3, 7, domain, QueryDoubleLabel)
|
qname, err := EncodeQuery(qk, 3, 7, domain, QueryMultiLabel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
// Double label: two hex labels before domain
|
// Multi-label mode splits hex across labels; all must be DNS-safe.
|
||||||
subdomain := qname[:len(qname)-len(domain)-1]
|
subdomain := qname[:len(qname)-len(domain)-1]
|
||||||
parts := strings.Split(subdomain, ".")
|
parts := strings.Split(subdomain, ".")
|
||||||
if len(parts) != 2 {
|
if len(parts) < 1 {
|
||||||
t.Errorf("double-label query should have 2 parts, got %d: %q", len(parts), subdomain)
|
t.Errorf("multi-label query should have at least 1 part, got %d: %q", len(parts), subdomain)
|
||||||
|
}
|
||||||
|
for _, p := range parts {
|
||||||
|
if len(p) == 0 || len(p) > 63 {
|
||||||
|
t.Errorf("invalid label length %d in %q", len(p), p)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ch, blk, err := DecodeQuery(qk, qname, domain)
|
ch, blk, err := DecodeQuery(qk, qname, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -68,6 +74,73 @@ func TestEncodeDecodeQueryDoubleLabel(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestEncodeQueryTooLongDomain(t *testing.T) {
|
||||||
|
qk, _, err := DeriveKeys("test-key")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeriveKeys: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 250-char domain should make qname exceed DNS 253-char limit.
|
||||||
|
longDomain := strings.Repeat("a", 250)
|
||||||
|
_, err = EncodeQuery(qk, 1, 1, longDomain, QueryMultiLabel)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for too-long domain")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncodeDecodeQueryPlainLabel(t *testing.T) {
|
||||||
|
qk, _, err := DeriveKeys("test-key")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeriveKeys: %v", err)
|
||||||
|
}
|
||||||
|
domain := "t.example.com"
|
||||||
|
tests := []struct {
|
||||||
|
channel uint16
|
||||||
|
block uint16
|
||||||
|
}{
|
||||||
|
{0, 0},
|
||||||
|
{1, 42},
|
||||||
|
{255, 65535},
|
||||||
|
{3, 100},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
qname, err := EncodeQuery(qk, tt.channel, tt.block, domain, QueryPlainLabel)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("EncodeQuery(%d, %d): %v", tt.channel, tt.block, err)
|
||||||
|
}
|
||||||
|
// Label should be "c<channel>b<block>" — human readable, no padding hex.
|
||||||
|
want := fmt.Sprintf("c%db%d.%s", tt.channel, tt.block, domain)
|
||||||
|
if qname != want {
|
||||||
|
t.Errorf("got %q, want %q", qname, want)
|
||||||
|
}
|
||||||
|
// DecodeQuery must recover channel and block regardless of key.
|
||||||
|
ch, blk, err := DecodeQuery(qk, qname, domain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DecodeQuery: %v", err)
|
||||||
|
}
|
||||||
|
if ch != tt.channel || blk != tt.block {
|
||||||
|
t.Errorf("got ch=%d blk=%d, want ch=%d blk=%d", ch, blk, tt.channel, tt.block)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPlainLabelNotConfusedWithEncrypted(t *testing.T) {
|
||||||
|
qk, _, err := DeriveKeys("test-key")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeriveKeys: %v", err)
|
||||||
|
}
|
||||||
|
domain := "t.example.com"
|
||||||
|
// Encode with single-label then check that DecodeQuery does NOT treat it as plain.
|
||||||
|
qname, _ := EncodeQuery(qk, 5, 10, domain, QuerySingleLabel)
|
||||||
|
ch, blk, err := DecodeQuery(qk, qname, domain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DecodeQuery single-label: %v", err)
|
||||||
|
}
|
||||||
|
if ch != 5 || blk != 10 {
|
||||||
|
t.Errorf("got ch=%d blk=%d, want ch=5 blk=10", ch, blk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDecodeQueryWrongKey(t *testing.T) {
|
func TestDecodeQueryWrongKey(t *testing.T) {
|
||||||
qk1, _, _ := DeriveKeys("key1")
|
qk1, _, _ := DeriveKeys("key1")
|
||||||
qk2, _, _ := DeriveKeys("key2")
|
qk2, _, _ := DeriveKeys("key2")
|
||||||
|
|||||||
@@ -1,14 +1,21 @@
|
|||||||
package protocol
|
package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math/big"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// DefaultBlockPayload is the decrypted payload per DNS TXT block.
|
// MinBlockPayload is the minimum decrypted payload per DNS TXT block.
|
||||||
// Calculated to stay within 512-byte UDP DNS limit after encryption + base64 + padding overhead.
|
MinBlockPayload = 200
|
||||||
DefaultBlockPayload = 180
|
// MaxBlockPayload is the maximum decrypted payload per DNS TXT block.
|
||||||
|
// 600 bytes data + 28 GCM overhead + 2 prefix + 32 padding → ~856 base64 chars.
|
||||||
|
// Well within the 4096-byte EDNS0 UDP buffer the client advertises.
|
||||||
|
MaxBlockPayload = 600
|
||||||
|
// DefaultBlockPayload is kept for compatibility; equals MaxBlockPayload.
|
||||||
|
DefaultBlockPayload = MaxBlockPayload
|
||||||
|
|
||||||
// DefaultMaxPadding is the default random padding added to responses to vary DNS response size.
|
// DefaultMaxPadding is the default random padding added to responses to vary DNS response size.
|
||||||
DefaultMaxPadding = 32
|
DefaultMaxPadding = 32
|
||||||
@@ -210,20 +217,31 @@ func ParseMessages(data []byte) ([]Message, error) {
|
|||||||
return msgs, nil
|
return msgs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SplitIntoBlocks splits data into blocks of DefaultBlockPayload size.
|
// SplitIntoBlocks splits data into blocks of randomly varying size in [MinBlockPayload, MaxBlockPayload].
|
||||||
|
// Random sizes make traffic analysis harder; the client just concatenates all blocks to reassemble.
|
||||||
func SplitIntoBlocks(data []byte) [][]byte {
|
func SplitIntoBlocks(data []byte) [][]byte {
|
||||||
if len(data) == 0 {
|
if len(data) == 0 {
|
||||||
return [][]byte{{}}
|
return [][]byte{{}} // channel 0 block 0 must always exist
|
||||||
}
|
}
|
||||||
var blocks [][]byte
|
var blocks [][]byte
|
||||||
for i := 0; i < len(data); i += DefaultBlockPayload {
|
rem := data
|
||||||
end := i + DefaultBlockPayload
|
for len(rem) > 0 {
|
||||||
if end > len(data) {
|
size := randBlockSize()
|
||||||
end = len(data)
|
if size > len(rem) {
|
||||||
|
size = len(rem)
|
||||||
}
|
}
|
||||||
block := make([]byte, end-i)
|
block := make([]byte, size)
|
||||||
copy(block, data[i:end])
|
copy(block, rem[:size])
|
||||||
blocks = append(blocks, block)
|
blocks = append(blocks, block)
|
||||||
|
rem = rem[size:]
|
||||||
}
|
}
|
||||||
return blocks
|
return blocks
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func randBlockSize() int {
|
||||||
|
n, err := rand.Int(rand.Reader, big.NewInt(int64(MaxBlockPayload-MinBlockPayload+1)))
|
||||||
|
if err != nil {
|
||||||
|
return (MinBlockPayload + MaxBlockPayload) / 2
|
||||||
|
}
|
||||||
|
return MinBlockPayload + int(n.Int64())
|
||||||
|
}
|
||||||
|
|||||||
@@ -57,19 +57,19 @@ func TestSerializeParseMessages(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSplitIntoBlocks(t *testing.T) {
|
func TestSplitIntoBlocks(t *testing.T) {
|
||||||
data := bytes.Repeat([]byte("A"), DefaultBlockPayload*3+50)
|
// MaxBlockPayload*3+50 guarantees at least 4 blocks (ceil((MaxBlockPayload*3+50)/MaxBlockPayload) = 4).
|
||||||
|
data := bytes.Repeat([]byte("A"), MaxBlockPayload*3+50)
|
||||||
blocks := SplitIntoBlocks(data)
|
blocks := SplitIntoBlocks(data)
|
||||||
if len(blocks) != 4 {
|
if len(blocks) < 4 {
|
||||||
t.Fatalf("blocks: got %d, want 4", len(blocks))
|
t.Fatalf("expected at least 4 blocks for %d bytes, got %d", len(data), len(blocks))
|
||||||
}
|
}
|
||||||
for i, b := range blocks {
|
// Every non-last block must be within [MinBlockPayload, MaxBlockPayload].
|
||||||
if i < 3 && len(b) != DefaultBlockPayload {
|
for i, b := range blocks[:len(blocks)-1] {
|
||||||
t.Errorf("block %d: size %d, want %d", i, len(b), DefaultBlockPayload)
|
if len(b) < MinBlockPayload || len(b) > MaxBlockPayload {
|
||||||
|
t.Errorf("block %d: size %d, want [%d, %d]", i, len(b), MinBlockPayload, MaxBlockPayload)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(blocks[3]) != 50 {
|
// Reassembled data must equal original.
|
||||||
t.Errorf("last block: size %d, want 50", len(blocks[3]))
|
|
||||||
}
|
|
||||||
var reassembled []byte
|
var reassembled []byte
|
||||||
for _, b := range blocks {
|
for _, b := range blocks {
|
||||||
reassembled = append(reassembled, b...)
|
reassembled = append(reassembled, b...)
|
||||||
|
|||||||
@@ -71,7 +71,9 @@ func TestFeedGetBlockUnknownChannel(t *testing.T) {
|
|||||||
|
|
||||||
func TestFeedLargeMessages(t *testing.T) {
|
func TestFeedLargeMessages(t *testing.T) {
|
||||||
feed := NewFeed([]string{"Test"})
|
feed := NewFeed([]string{"Test"})
|
||||||
largeText := make([]byte, 500)
|
// Use text large enough to span 2 blocks at DefaultBlockPayload (currently 700 bytes).
|
||||||
|
// Message serialization overhead is 10 bytes, so we need >690 bytes of text.
|
||||||
|
largeText := make([]byte, 750)
|
||||||
for i := range largeText {
|
for i := range largeText {
|
||||||
largeText[i] = 65
|
largeText[i] = 65
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ func NewTelegramReader(cfg TelegramConfig, channelUsernames []string, feed *Feed
|
|||||||
channels: cleaned,
|
channels: cleaned,
|
||||||
feed: feed,
|
feed: feed,
|
||||||
cache: make(map[string]cachedMessages),
|
cache: make(map[string]cachedMessages),
|
||||||
cacheTTL: 1 * time.Minute,
|
cacheTTL: 5 * time.Minute,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,7 +115,7 @@ func (tr *TelegramReader) Run(ctx context.Context) error {
|
|||||||
tr.fetchAll(ctx, api)
|
tr.fetchAll(ctx, api)
|
||||||
|
|
||||||
// Periodic fetch loop
|
// Periodic fetch loop
|
||||||
ticker := time.NewTicker(1 * time.Minute)
|
ticker := time.NewTicker(3 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|||||||
+165
-17
@@ -142,6 +142,44 @@
|
|||||||
word-break: break-word;
|
word-break: break-word;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.progress-panel {
|
||||||
|
height: 56px;
|
||||||
|
min-height: 56px;
|
||||||
|
background: var(--surface);
|
||||||
|
border-top: 1px solid var(--border);
|
||||||
|
border-bottom: 1px solid var(--border);
|
||||||
|
overflow-y: auto;
|
||||||
|
font-size: 12px;
|
||||||
|
font-family: 'Vazirmatn', monospace;
|
||||||
|
padding: 12px 12px;
|
||||||
|
direction: ltr;
|
||||||
|
text-align: left;
|
||||||
|
color: var(--text-dim);
|
||||||
|
}
|
||||||
|
.progress-item {
|
||||||
|
margin-bottom: 8px;
|
||||||
|
padding: 6px;
|
||||||
|
border-radius: 4px;
|
||||||
|
background: var(--bg);
|
||||||
|
}
|
||||||
|
.progress-label {
|
||||||
|
font-size: 11px;
|
||||||
|
margin-bottom: 3px;
|
||||||
|
color: var(--text-dim);
|
||||||
|
}
|
||||||
|
.progress-bar {
|
||||||
|
width: 100%;
|
||||||
|
height: 6px;
|
||||||
|
background: var(--border);
|
||||||
|
border-radius: 3px;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
.progress-fill {
|
||||||
|
height: 100%;
|
||||||
|
background: var(--accent);
|
||||||
|
transition: width 0.2s;
|
||||||
|
}
|
||||||
|
|
||||||
.log-panel {
|
.log-panel {
|
||||||
height: 150px;
|
height: 150px;
|
||||||
min-height: 100px;
|
min-height: 100px;
|
||||||
@@ -149,13 +187,21 @@
|
|||||||
border-top: 1px solid var(--border);
|
border-top: 1px solid var(--border);
|
||||||
overflow-y: auto;
|
overflow-y: auto;
|
||||||
font-size: 12px;
|
font-size: 12px;
|
||||||
font-family: 'Vazirmatn', monospace;
|
font-family: monospace;
|
||||||
padding: 8px 12px;
|
padding: 8px 12px;
|
||||||
direction: ltr;
|
direction: ltr;
|
||||||
text-align: left;
|
text-align: left;
|
||||||
color: var(--text-dim);
|
color: var(--text-dim);
|
||||||
}
|
}
|
||||||
.log-line { padding: 2px 0; }
|
.log-line {
|
||||||
|
padding: 2px 0;
|
||||||
|
line-height: 1.3;
|
||||||
|
}
|
||||||
|
.log-line.info { color: #60a5fa; }
|
||||||
|
.log-line.progress { color: #fbbf24; }
|
||||||
|
.log-line.error { color: #ef4444; }
|
||||||
|
.log-line.success { color: #22c55e; }
|
||||||
|
.log-line.warning { color: #f97316; }
|
||||||
|
|
||||||
.modal-overlay {
|
.modal-overlay {
|
||||||
display: none;
|
display: none;
|
||||||
@@ -245,6 +291,7 @@
|
|||||||
<button class="btn btn-primary" onclick="openSettings()">Configure</button>
|
<button class="btn btn-primary" onclick="openSettings()">Configure</button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
<div class="progress-panel" id="progressPanel"></div>
|
||||||
<div class="log-panel" id="logPanel"></div>
|
<div class="log-panel" id="logPanel"></div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -268,12 +315,17 @@
|
|||||||
<label>Query Mode</label>
|
<label>Query Mode</label>
|
||||||
<select id="cfgQueryMode">
|
<select id="cfgQueryMode">
|
||||||
<option value="single">Single label (base32, stealthier)</option>
|
<option value="single">Single label (base32, stealthier)</option>
|
||||||
<option value="double">Double label (hex)</option>
|
<option value="double">Multi-label (hex)</option>
|
||||||
|
<option value="plain">Plain text (no query encryption)</option>
|
||||||
</select>
|
</select>
|
||||||
</div>
|
</div>
|
||||||
<div class="form-group">
|
<div class="form-group">
|
||||||
<label>Rate Limit (queries/sec, 0 = unlimited)</label>
|
<label>Rate Limit (queries/sec, 0 = unlimited)</label>
|
||||||
<input type="number" id="cfgRateLimit" value="0" min="0" step="0.1">
|
<input type="number" id="cfgRateLimit" value="100" min="0" step="0.1">
|
||||||
|
</div>
|
||||||
|
<div class="form-group" style="flex-direction:row;align-items:center;gap:10px;">
|
||||||
|
<input type="checkbox" id="cfgDebug" style="width:auto;margin:0;">
|
||||||
|
<label for="cfgDebug" style="margin:0;cursor:pointer;">Debug mode (log generated query names)</label>
|
||||||
</div>
|
</div>
|
||||||
<div style="display:flex;gap:8px;justify-content:flex-end;margin-top:20px;">
|
<div style="display:flex;gap:8px;justify-content:flex-end;margin-top:20px;">
|
||||||
<button class="btn" onclick="closeSettings()">Cancel</button>
|
<button class="btn" onclick="closeSettings()">Cancel</button>
|
||||||
@@ -287,6 +339,7 @@
|
|||||||
let selectedChannel = 0;
|
let selectedChannel = 0;
|
||||||
let channels = [];
|
let channels = [];
|
||||||
let eventSource = null;
|
let eventSource = null;
|
||||||
|
let autoRefreshTimer = null;
|
||||||
|
|
||||||
async function init() {
|
async function init() {
|
||||||
try {
|
try {
|
||||||
@@ -296,11 +349,16 @@
|
|||||||
openSettings();
|
openSettings();
|
||||||
} else {
|
} else {
|
||||||
document.getElementById('statusDot').className = 'status-dot connected';
|
document.getElementById('statusDot').className = 'status-dot connected';
|
||||||
if (status.channels && status.channels.length > 0) {
|
await doRefresh();
|
||||||
channels = status.channels;
|
await loadChannels();
|
||||||
renderChannels();
|
if (channels && channels.length > 0) {
|
||||||
selectChannel(1);
|
await selectChannel(1);
|
||||||
}
|
}
|
||||||
|
autoRefreshTimer = setInterval(function() {
|
||||||
|
if (selectedChannel > 0) {
|
||||||
|
doRefresh();
|
||||||
|
}
|
||||||
|
}, 120000);
|
||||||
}
|
}
|
||||||
} catch (e) {}
|
} catch (e) {}
|
||||||
connectSSE();
|
connectSSE();
|
||||||
@@ -312,12 +370,23 @@
|
|||||||
eventSource.addEventListener('log', function(e) {
|
eventSource.addEventListener('log', function(e) {
|
||||||
addLogLine(JSON.parse(e.data));
|
addLogLine(JSON.parse(e.data));
|
||||||
});
|
});
|
||||||
eventSource.addEventListener('update', function(e) {
|
eventSource.addEventListener('update', async function(e) {
|
||||||
loadChannels();
|
var wasEmpty = channels.length === 0;
|
||||||
if (selectedChannel > 0) loadMessages(selectedChannel);
|
await loadChannels();
|
||||||
|
if (wasEmpty && channels.length > 0 && selectedChannel === 0) {
|
||||||
|
// Channels just appeared for the first time — auto-select the first one.
|
||||||
|
selectChannel(1);
|
||||||
|
} else if (selectedChannel > 0) {
|
||||||
|
loadMessages(selectedChannel);
|
||||||
|
}
|
||||||
});
|
});
|
||||||
eventSource.onerror = function() {
|
eventSource.onerror = function() {
|
||||||
document.getElementById('statusDot').className = 'status-dot disconnected';
|
document.getElementById('statusDot').className = 'status-dot disconnected';
|
||||||
|
// EventSource.CLOSED (2) means the browser stopped retrying — reconnect manually.
|
||||||
|
if (eventSource.readyState === EventSource.CLOSED) {
|
||||||
|
eventSource.close();
|
||||||
|
setTimeout(connectSSE, 3000);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
eventSource.onopen = function() {
|
eventSource.onopen = function() {
|
||||||
document.getElementById('statusDot').className = 'status-dot connected';
|
document.getElementById('statusDot').className = 'status-dot connected';
|
||||||
@@ -343,6 +412,7 @@
|
|||||||
async function selectChannel(num) {
|
async function selectChannel(num) {
|
||||||
selectedChannel = num;
|
selectedChannel = num;
|
||||||
renderChannels();
|
renderChannels();
|
||||||
|
await doRefresh();
|
||||||
await loadMessages(num);
|
await loadMessages(num);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -392,16 +462,79 @@
|
|||||||
var el = document.getElementById('logPanel');
|
var el = document.getElementById('logPanel');
|
||||||
var div = document.createElement('div');
|
var div = document.createElement('div');
|
||||||
div.className = 'log-line';
|
div.className = 'log-line';
|
||||||
|
|
||||||
|
// Parse log level from line content
|
||||||
|
var level = 'info';
|
||||||
|
var displayText = line;
|
||||||
|
|
||||||
|
if (typeof line === 'string') {
|
||||||
|
if (line.includes('Error:') || line.includes('error') || line.includes('Error')) {
|
||||||
|
level = 'error';
|
||||||
|
} else if (line.includes('Warning:') || line.includes('warning') || line.includes('Warning')) {
|
||||||
|
level = 'warning';
|
||||||
|
} else if (line.includes('OK:') || line.includes('OK') || line.includes('success') || line.includes('done')) {
|
||||||
|
level = 'success';
|
||||||
|
} else if (line.match(/\[\d+%\]|\d+\s*%/)) {
|
||||||
|
level = 'progress';
|
||||||
|
updateProgressDisplay(line);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
div.className = 'log-line ' + level;
|
||||||
div.textContent = line;
|
div.textContent = line;
|
||||||
el.appendChild(div);
|
el.appendChild(div);
|
||||||
el.scrollTop = el.scrollHeight;
|
el.scrollTop = el.scrollHeight;
|
||||||
|
|
||||||
|
// Keep only last 200 lines
|
||||||
while (el.children.length > 200) {
|
while (el.children.length > 200) {
|
||||||
el.removeChild(el.firstChild);
|
el.removeChild(el.firstChild);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function updateProgressDisplay(line) {
|
||||||
|
// Parse progress from "Channel N [====> ] 45%"
|
||||||
|
var match = line.match(/Channel\s+(\d+)/);
|
||||||
|
if (!match) return;
|
||||||
|
|
||||||
|
var channelNum = match[1];
|
||||||
|
var percentMatch = line.match(/(\d+)%/);
|
||||||
|
var percent = percentMatch ? parseInt(percentMatch[1]) : 0;
|
||||||
|
|
||||||
|
var panel = document.getElementById('progressPanel');
|
||||||
|
var itemId = 'progress-current';
|
||||||
|
var item = document.getElementById(itemId);
|
||||||
|
|
||||||
|
if (!item) {
|
||||||
|
item = document.createElement('div');
|
||||||
|
item.id = itemId;
|
||||||
|
item.className = 'progress-item';
|
||||||
|
item.innerHTML = '<div class="progress-label">Channel ' + channelNum + '</div>' +
|
||||||
|
'<div class="progress-bar"><div class="progress-fill" style="width:0%"></div></div>';
|
||||||
|
panel.appendChild(item);
|
||||||
|
} else {
|
||||||
|
item.querySelector('.progress-label').textContent = 'Channel ' + channelNum;
|
||||||
|
}
|
||||||
|
|
||||||
|
var fill = item.querySelector('.progress-fill');
|
||||||
|
fill.style.width = percent + '%';
|
||||||
|
|
||||||
|
// Remove if complete
|
||||||
|
if (percent >= 100) {
|
||||||
|
setTimeout(function() {
|
||||||
|
if (item.parentNode) item.parentNode.removeChild(item);
|
||||||
|
}, 1000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async function doRefresh() {
|
async function doRefresh() {
|
||||||
try { await fetch('/api/refresh', {method: 'POST'}); } catch (e) {}
|
try {
|
||||||
|
var url = '/api/refresh';
|
||||||
|
if (selectedChannel > 0) {
|
||||||
|
url += '?channel=' + selectedChannel;
|
||||||
|
}
|
||||||
|
await fetch(url, {method: 'POST'});
|
||||||
|
} catch (e) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
function openSettings() {
|
function openSettings() {
|
||||||
@@ -411,7 +544,12 @@
|
|||||||
if (cfg.key) document.getElementById('cfgKey').value = cfg.key;
|
if (cfg.key) document.getElementById('cfgKey').value = cfg.key;
|
||||||
if (cfg.resolvers) document.getElementById('cfgResolvers').value = cfg.resolvers.join('\n');
|
if (cfg.resolvers) document.getElementById('cfgResolvers').value = cfg.resolvers.join('\n');
|
||||||
if (cfg.queryMode) document.getElementById('cfgQueryMode').value = cfg.queryMode;
|
if (cfg.queryMode) document.getElementById('cfgQueryMode').value = cfg.queryMode;
|
||||||
if (cfg.rateLimit) document.getElementById('cfgRateLimit').value = cfg.rateLimit;
|
if (typeof cfg.rateLimit === 'number') {
|
||||||
|
document.getElementById('cfgRateLimit').value = cfg.rateLimit;
|
||||||
|
} else {
|
||||||
|
document.getElementById('cfgRateLimit').value = 100;
|
||||||
|
}
|
||||||
|
document.getElementById('cfgDebug').checked = !!cfg.debug;
|
||||||
}).catch(function() {});
|
}).catch(function() {});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -432,7 +570,8 @@
|
|||||||
key: document.getElementById('cfgKey').value,
|
key: document.getElementById('cfgKey').value,
|
||||||
resolvers: resolvers,
|
resolvers: resolvers,
|
||||||
queryMode: document.getElementById('cfgQueryMode').value,
|
queryMode: document.getElementById('cfgQueryMode').value,
|
||||||
rateLimit: parseFloat(document.getElementById('cfgRateLimit').value) || 0
|
rateLimit: parseFloat(document.getElementById('cfgRateLimit').value) || 100,
|
||||||
|
debug: document.getElementById('cfgDebug').checked
|
||||||
};
|
};
|
||||||
|
|
||||||
if (!cfg.domain || !cfg.key || resolvers.length === 0) {
|
if (!cfg.domain || !cfg.key || resolvers.length === 0) {
|
||||||
@@ -453,9 +592,18 @@
|
|||||||
}
|
}
|
||||||
closeSettings();
|
closeSettings();
|
||||||
document.getElementById('statusDot').className = 'status-dot connected';
|
document.getElementById('statusDot').className = 'status-dot connected';
|
||||||
setTimeout(function() {
|
await doRefresh();
|
||||||
loadChannels();
|
await loadChannels();
|
||||||
}, 2000);
|
if (channels && channels.length > 0) {
|
||||||
|
await selectChannel(1);
|
||||||
|
}
|
||||||
|
if (!autoRefreshTimer) {
|
||||||
|
autoRefreshTimer = setInterval(function() {
|
||||||
|
if (selectedChannel > 0) {
|
||||||
|
fetch('/api/refresh?channel=' + selectedChannel + '&quiet=1', {method: 'POST'});
|
||||||
|
}
|
||||||
|
}, 120000);
|
||||||
|
}
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
errEl.textContent = e.message;
|
errEl.textContent = e.message;
|
||||||
errEl.style.display = 'block';
|
errEl.style.display = 'block';
|
||||||
|
|||||||
+176
-50
@@ -1,6 +1,7 @@
|
|||||||
package web
|
package web
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"embed"
|
"embed"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -29,6 +30,11 @@ type Config struct {
|
|||||||
Resolvers []string `json:"resolvers"`
|
Resolvers []string `json:"resolvers"`
|
||||||
QueryMode string `json:"queryMode"`
|
QueryMode string `json:"queryMode"`
|
||||||
RateLimit float64 `json:"rateLimit"`
|
RateLimit float64 `json:"rateLimit"`
|
||||||
|
// Timeout is the per-query DNS timeout in seconds (0 = default 5 s).
|
||||||
|
// Also used as the resolver health-check probe timeout.
|
||||||
|
Timeout float64 `json:"timeout,omitempty"`
|
||||||
|
// Debug enables verbose query logging (shows generated DNS query names).
|
||||||
|
Debug bool `json:"debug,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Server is the web UI server for thefeed client.
|
// Server is the web UI server for thefeed client.
|
||||||
@@ -43,6 +49,16 @@ type Server struct {
|
|||||||
channels []protocol.ChannelInfo
|
channels []protocol.ChannelInfo
|
||||||
messages map[int][]protocol.Message
|
messages map[int][]protocol.Message
|
||||||
|
|
||||||
|
// fetcherCtx/fetcherCancel control the lifetime of the active fetcher's
|
||||||
|
// background goroutines (rate limiter, noise, resolver checker).
|
||||||
|
// They are cancelled and recreated each time the config changes.
|
||||||
|
fetcherCtx context.Context
|
||||||
|
fetcherCancel context.CancelFunc
|
||||||
|
|
||||||
|
// refreshMu / refreshCancel allow a new refresh to cancel an in-progress one.
|
||||||
|
refreshMu sync.Mutex
|
||||||
|
refreshCancel context.CancelFunc
|
||||||
|
|
||||||
logMu sync.RWMutex
|
logMu sync.RWMutex
|
||||||
logLines []string
|
logLines []string
|
||||||
|
|
||||||
@@ -96,7 +112,7 @@ func (s *Server) Run() error {
|
|||||||
fmt.Printf("\n Open in browser: http://%s\n\n", addr)
|
fmt.Printf("\n Open in browser: http://%s\n\n", addr)
|
||||||
|
|
||||||
if s.fetcher != nil {
|
if s.fetcher != nil {
|
||||||
s.startAutoRefresh()
|
go s.refreshMetadataOnly()
|
||||||
}
|
}
|
||||||
|
|
||||||
return http.ListenAndServe(addr, mux)
|
return http.ListenAndServe(addr, mux)
|
||||||
@@ -164,7 +180,7 @@ func (s *Server) handleConfig(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.Error(w, fmt.Sprintf("init fetcher: %v", err), 500)
|
http.Error(w, fmt.Sprintf("init fetcher: %v", err), 500)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.startAutoRefresh()
|
go s.refreshMetadataOnly()
|
||||||
writeJSON(w, map[string]any{"ok": true})
|
writeJSON(w, map[string]any{"ok": true})
|
||||||
|
|
||||||
default:
|
default:
|
||||||
@@ -202,7 +218,28 @@ func (s *Server) handleRefresh(w http.ResponseWriter, r *http.Request) {
|
|||||||
http.Error(w, "method not allowed", 405)
|
http.Error(w, "method not allowed", 405)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go s.refresh()
|
// Background (quiet) refreshes skip silently if one is already running,
|
||||||
|
// so the auto-refresh timer never cancels a slow in-progress fetch.
|
||||||
|
if r.URL.Query().Get("quiet") == "1" {
|
||||||
|
s.refreshMu.Lock()
|
||||||
|
running := s.refreshCancel != nil
|
||||||
|
s.refreshMu.Unlock()
|
||||||
|
if running {
|
||||||
|
writeJSON(w, map[string]any{"ok": true, "skipped": true})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
chParam := r.URL.Query().Get("channel")
|
||||||
|
if chParam != "" {
|
||||||
|
chNum, err := strconv.Atoi(chParam)
|
||||||
|
if err != nil || chNum < 1 {
|
||||||
|
http.Error(w, "invalid channel", 400)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go s.refreshChannel(chNum)
|
||||||
|
} else {
|
||||||
|
go s.refreshMetadataOnly()
|
||||||
|
}
|
||||||
writeJSON(w, map[string]any{"ok": true})
|
writeJSON(w, map[string]any{"ok": true})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,6 +315,11 @@ func (s *Server) initFetcher() error {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
// Cancel goroutines from the previous fetcher configuration.
|
||||||
|
if s.fetcherCancel != nil {
|
||||||
|
s.fetcherCancel()
|
||||||
|
}
|
||||||
|
|
||||||
cfg := s.config
|
cfg := s.config
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return fmt.Errorf("no config")
|
return fmt.Errorf("no config")
|
||||||
@@ -295,57 +337,81 @@ func (s *Server) initFetcher() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if cfg.QueryMode == "double" {
|
if cfg.QueryMode == "double" {
|
||||||
fetcher.SetQueryMode(protocol.QueryDoubleLabel)
|
fetcher.SetQueryMode(protocol.QueryMultiLabel)
|
||||||
|
} else if cfg.QueryMode == "plain" {
|
||||||
|
fetcher.SetQueryMode(protocol.QueryPlainLabel)
|
||||||
}
|
}
|
||||||
|
fetcher.SetDebug(cfg.Debug)
|
||||||
if cfg.RateLimit > 0 {
|
if cfg.RateLimit > 0 {
|
||||||
fetcher.SetRateLimit(cfg.RateLimit)
|
fetcher.SetRateLimit(cfg.RateLimit)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
timeout := 5 * time.Second
|
||||||
|
if cfg.Timeout > 0 {
|
||||||
|
timeout = time.Duration(cfg.Timeout * float64(time.Second))
|
||||||
|
}
|
||||||
|
fetcher.SetTimeout(timeout)
|
||||||
|
|
||||||
fetcher.SetLogFunc(func(msg string) {
|
fetcher.SetLogFunc(func(msg string) {
|
||||||
s.addLog(msg)
|
s.addLog(msg)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Create a shared context for this fetcher's lifetime.
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
s.fetcherCtx = ctx
|
||||||
|
s.fetcherCancel = cancel
|
||||||
|
|
||||||
|
// Start rate limiter and noise goroutines.
|
||||||
|
fetcher.Start(ctx)
|
||||||
|
|
||||||
|
// Start periodic resolver health checks.
|
||||||
|
checker := client.NewResolverChecker(fetcher, timeout)
|
||||||
|
checker.SetLogFunc(func(msg string) {
|
||||||
|
s.addLog(msg)
|
||||||
|
})
|
||||||
|
checker.Start(ctx)
|
||||||
|
|
||||||
s.fetcher = fetcher
|
s.fetcher = fetcher
|
||||||
s.cache = cache
|
s.cache = cache
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) startAutoRefresh() {
|
func (s *Server) refreshMetadataOnly() {
|
||||||
if s.stopRefresh != nil {
|
// Cancel any in-progress refresh and start a new cancellable one.
|
||||||
close(s.stopRefresh)
|
s.refreshMu.Lock()
|
||||||
|
if s.refreshCancel != nil {
|
||||||
|
s.refreshCancel()
|
||||||
}
|
}
|
||||||
s.stopRefresh = make(chan struct{})
|
|
||||||
stop := s.stopRefresh
|
|
||||||
|
|
||||||
go s.refresh()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
ticker := time.NewTicker(30 * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-stop:
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
s.refresh()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) refresh() {
|
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
|
basectx := s.fetcherCtx
|
||||||
fetcher := s.fetcher
|
fetcher := s.fetcher
|
||||||
cache := s.cache
|
cache := s.cache
|
||||||
s.mu.RUnlock()
|
s.mu.RUnlock()
|
||||||
|
|
||||||
if fetcher == nil {
|
if fetcher == nil || basectx == nil {
|
||||||
|
s.refreshMu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Child context: cancelled either by the next refresh call or by a config change.
|
||||||
|
ctx, cancel := context.WithCancel(basectx)
|
||||||
|
s.refreshCancel = cancel
|
||||||
|
s.refreshMu.Unlock()
|
||||||
|
defer func() {
|
||||||
|
cancel()
|
||||||
|
s.refreshMu.Lock()
|
||||||
|
s.refreshCancel = nil
|
||||||
|
s.refreshMu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
s.addLog("Fetching metadata...")
|
s.addLog("Fetching metadata...")
|
||||||
meta, err := fetcher.FetchMetadata()
|
meta, err := fetcher.FetchMetadata(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
s.addLog("Refresh cancelled")
|
||||||
|
return
|
||||||
|
}
|
||||||
s.addLog(fmt.Sprintf("Error: %v", err))
|
s.addLog(fmt.Sprintf("Error: %v", err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -359,31 +425,91 @@ func (s *Server) refresh() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.broadcast("event: update\ndata: \"channels\"\n\n")
|
s.broadcast("event: update\ndata: \"channels\"\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
for i, ch := range meta.Channels {
|
func (s *Server) refreshChannel(channelNum int) {
|
||||||
chNum := i + 1
|
s.refreshMu.Lock()
|
||||||
blockCount := int(ch.Blocks)
|
if s.refreshCancel != nil {
|
||||||
if blockCount <= 0 {
|
s.refreshCancel()
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
msgs, err := fetcher.FetchChannel(chNum, blockCount)
|
|
||||||
if err != nil {
|
|
||||||
s.addLog(fmt.Sprintf("Channel %s error: %v", ch.Name, err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
s.mu.Lock()
|
|
||||||
s.messages[chNum] = msgs
|
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
if cache != nil {
|
|
||||||
_ = cache.PutMessages(chNum, msgs)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.addLog(fmt.Sprintf("Updated %s: %d messages", ch.Name, len(msgs)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.mu.RLock()
|
||||||
|
basectx := s.fetcherCtx
|
||||||
|
fetcher := s.fetcher
|
||||||
|
cache := s.cache
|
||||||
|
channels := s.channels
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if fetcher == nil || basectx == nil {
|
||||||
|
s.refreshMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(basectx)
|
||||||
|
s.refreshCancel = cancel
|
||||||
|
s.refreshMu.Unlock()
|
||||||
|
defer func() {
|
||||||
|
cancel()
|
||||||
|
s.refreshMu.Lock()
|
||||||
|
s.refreshCancel = nil
|
||||||
|
s.refreshMu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
meta, err := fetcher.FetchMetadata(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
s.addLog("Refresh cancelled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.addLog(fmt.Sprintf("Error: %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.channels = meta.Channels
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if cache != nil {
|
||||||
|
_ = cache.PutMetadata(meta)
|
||||||
|
}
|
||||||
|
s.broadcast("event: update\ndata: \"channels\"\n\n")
|
||||||
|
|
||||||
|
channels = meta.Channels
|
||||||
|
if channelNum < 1 || channelNum > len(channels) {
|
||||||
|
s.addLog(fmt.Sprintf("Warning: channel %d is not available", channelNum))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := channels[channelNum-1]
|
||||||
|
blockCount := int(ch.Blocks)
|
||||||
|
if blockCount <= 0 {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.messages[channelNum] = nil
|
||||||
|
s.mu.Unlock()
|
||||||
|
s.addLog(fmt.Sprintf("Updated %s: 0 messages", ch.Name))
|
||||||
|
s.broadcast("event: update\ndata: \"messages\"\n\n")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs, err := fetcher.FetchChannel(ctx, channelNum, blockCount)
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
s.addLog("Refresh cancelled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.addLog(fmt.Sprintf("Channel %s error: %v", ch.Name, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.messages[channelNum] = msgs
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
if cache != nil {
|
||||||
|
_ = cache.PutMessages(channelNum, msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.addLog(fmt.Sprintf("Updated %s: %d messages", ch.Name, len(msgs)))
|
||||||
s.broadcast("event: update\ndata: \"messages\"\n\n")
|
s.broadcast("event: update\ndata: \"messages\"\n\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user