commit ed991baa0f88808e9891aef98dcdcd67a08b7eb7 Author: ThisIsDara <1380katana@gmail.com> Date: Mon May 4 07:17:57 2026 +0330 Initial commit: MHR-CFW Go v1.1.0 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..517d397 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +.exe +*.exe~ +*.dll +*.so +*.dylib +*.test +*.out +go.work +.DS_Store +Thumbs.db \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..1c627c6 --- /dev/null +++ b/README.md @@ -0,0 +1,59 @@ +# MHR-CFW Go + +A domain-fronted HTTP/HTTPS proxy relay suite for Apps Script written in Go. + +## Features + +- **HTTP Proxy** - Local proxy server with CONNECT tunnel support +- **SOCKS5 Proxy** - Built-in SOCKS5 server +- **Domain Fronting** - Google Apps Script relay via domain-fronted requests +- **MITM Proxy** - Dynamic certificate generation for HTTPS interception +- **HTTP/2 Transport** - Performance-optimized HTTP/2 connections +- **TUI Menu** - Interactive terminal menu for easy operation + +## Usage + +```powershell +# Run with TUI menu +.\mhr-cfw-go.exe + +# Run proxy directly (requires configured config.json) +.\mhr-cfw-go.exe --no-menu + +# Install CA certificate +.\mhr-cfw-go.exe --install-cert + +# Scan Google IPs +.\mhr-cfw-go.exe --scan + +# Show version +.\mhr-cfw-go.exe --version +``` + +## Configuration + +Edit `config.json` before running: + +```json +{ + "auth_key": "your-secret-key", + "script_id": "YOUR_APPS_SCRIPT_DEPLOYMENT_ID", + "front_domain": "www.google.com", + "google_ip": "216.239.38.120" +} +``` + +## Requirements + +- Go 1.21+ +- Windows (for certificate installation) + +## Building + +```powershell +go build -o mhr-cfw-go.exe ./cmd/mhr-cfw +``` + +## License + +MIT \ No newline at end of file diff --git a/cmd/mhr-cfw/main.go b/cmd/mhr-cfw/main.go new file mode 100644 index 0000000..78e4a02 --- /dev/null +++ b/cmd/mhr-cfw/main.go @@ -0,0 +1,296 @@ +package main + +import ( + "context" + "errors" + "flag" + "fmt" + "os" + "os/signal" + "path/filepath" + "strings" + "syscall" + "time" + + "github.com/denuitt1/mhr-cfw/internal/cert" + "github.com/denuitt1/mhr-cfw/internal/config" + "github.com/denuitt1/mhr-cfw/internal/constants" + "github.com/denuitt1/mhr-cfw/internal/lan" + "github.com/denuitt1/mhr-cfw/internal/logging" + "github.com/denuitt1/mhr-cfw/internal/mitm" + "github.com/denuitt1/mhr-cfw/internal/proxy" + "github.com/denuitt1/mhr-cfw/internal/scanner" + "github.com/denuitt1/mhr-cfw/internal/setup" + "github.com/denuitt1/mhr-cfw/internal/tui" +) + +var placeholderAuthKeys = map[string]bool{ + "": true, + "CHANGE_ME_TO_A_STRONG_SECRET": true, + "your-secret-password-here": true, +} + +type args struct { + configPath string + port int + host string + socksPort int + disableSocks bool + logLevel string + installCert bool + uninstallCert bool + noCertCheck bool + scan bool +} + +func parseArgs() (*args, error) { + a := &args{} + flag.StringVar(&a.configPath, "config", envOr("DFT_CONFIG", "config.json"), "Path to config file (default: config.json, env: DFT_CONFIG)") + flag.IntVar(&a.port, "port", 0, "Override listen port (env: DFT_PORT)") + flag.StringVar(&a.host, "host", "", "Override listen host (env: DFT_HOST)") + flag.IntVar(&a.socksPort, "socks5-port", 0, "Override SOCKS5 listen port (env: DFT_SOCKS5_PORT)") + flag.BoolVar(&a.disableSocks, "disable-socks5", false, "Disable the built-in SOCKS5 listener") + flag.StringVar(&a.logLevel, "log-level", "", "Override log level (env: DFT_LOG_LEVEL)") + flag.BoolVar(&a.installCert, "install-cert", false, "Install the MITM CA certificate as a trusted root and exit") + flag.BoolVar(&a.uninstallCert, "uninstall-cert", false, "Remove the MITM CA certificate from trusted roots and exit") + flag.BoolVar(&a.noCertCheck, "no-cert-check", false, "Skip the certificate installation check on startup") + flag.BoolVar(&a.scan, "scan", false, "Scan Google IPs to find the fastest reachable one and exit") + setupFlag := flag.Bool("setup", false, "Run interactive setup wizard and exit") + noMenu := flag.Bool("no-menu", false, "Run without the interactive TUI menu") + showVersion := flag.Bool("version", false, "Print version and exit") + flag.Parse() + + if *showVersion { + fmt.Printf("domainfront-tunnel %s\n", constants.Version) + os.Exit(0) + } + if *setupFlag { + if err := setup.RunInteractiveWizard(a.configPath); err != nil { + fmt.Println("Setup failed:", err) + os.Exit(1) + } + os.Exit(0) + } + if !*noMenu && isTTY(os.Stdin) { + if err := runMenu(a); err != nil { + fmt.Println("Menu error:", err) + os.Exit(1) + } + os.Exit(0) + } + return a, nil +} + +func main() { + if _, err := parseArgs(); err != nil { + fmt.Fprintln(os.Stderr, "args error:", err) + os.Exit(2) + } +} + +func runMenu(a *args) error { + menu := &tui.Menu{ + Title: "mhr-cfw", + Options: []tui.Option{ + {Key: 1, Label: "Start proxy", Handler: func() error { return runProxy(a) }}, + {Key: 2, Label: "Setup wizard", Handler: func() error { return setup.RunInteractiveWizard(a.configPath) }}, + {Key: 3, Label: "Install CA certificate", Handler: func() error { + logging.Configure("INFO") + if !fileExists(mitm.CACertFile) { + _ = mitm.NewManager() + } + if cert.InstallCA(mitm.CACertFile, cert.DefaultCertName) { + fmt.Println("[OK] CA installed") + return nil + } + return errors.New("CA install failed") + }}, + {Key: 4, Label: "Uninstall CA certificate", Handler: func() error { + logging.Configure("INFO") + if cert.UninstallCA(mitm.CACertFile, cert.DefaultCertName) { + fmt.Println("[OK] CA removed") + return nil + } + return errors.New("CA removal failed") + }}, + {Key: 5, Label: "Scan Google IPs", Handler: func() error { + cfg, err := config.Load(a.configPath) + if err != nil { + return err + } + logging.Configure("INFO") + frontDomain := cfg.GetString("front_domain", "www.google.com") + fmt.Println("\nScanning... this can take a minute on slow networks.") + ok := scanner.ScanSync(frontDomain) + if !ok { + return errors.New("no reachable IPs") + } + return nil + }}, + {Key: 6, Label: "Exit", Handler: nil}, + }, + } + return menu.Run() +} + +func runProxy(a *args) error { + if a.installCert || a.uninstallCert { + return nil + } + + cfgPath := a.configPath + cfg, err := config.Load(cfgPath) + if err != nil { + return err + } + + if v := os.Getenv("DFT_AUTH_KEY"); v != "" { + cfg.Set("auth_key", v) + } + if v := os.Getenv("DFT_SCRIPT_ID"); v != "" { + cfg.Set("script_id", v) + } + if v := os.Getenv("DFT_PORT"); v != "" { + cfg.Set("listen_port", config.ToInt(v, cfg.GetInt("listen_port", 8080))) + } + if v := os.Getenv("DFT_HOST"); v != "" { + cfg.Set("listen_host", v) + } + if v := os.Getenv("DFT_SOCKS5_PORT"); v != "" { + cfg.Set("socks5_port", config.ToInt(v, cfg.GetInt("socks5_port", 1080))) + } + if v := os.Getenv("DFT_LOG_LEVEL"); v != "" { + cfg.Set("log_level", v) + } + + if a.port != 0 { + cfg.Set("listen_port", a.port) + } + if a.host != "" { + cfg.Set("listen_host", a.host) + } + if a.socksPort != 0 { + cfg.Set("socks5_port", a.socksPort) + } + if a.disableSocks { + cfg.Set("socks5_enabled", false) + } + if a.logLevel != "" { + cfg.Set("log_level", a.logLevel) + } + + if placeholderAuthKeys[strings.TrimSpace(cfg.GetString("auth_key", ""))] { + return errors.New("refusing to start: auth_key is unset or placeholder") + } + + cfg.Set("mode", "apps_script") + sid := cfg.GetScriptID() + if sid == "" || sid == "YOUR_APPS_SCRIPT_DEPLOYMENT_ID" { + return errors.New("missing script_id in config") + } + + logging.Configure(cfg.GetString("log_level", "INFO")) + log := logging.Get("Main") + logging.PrintBanner(constants.Version) + log.Infof("DomainFront Tunnel starting (Apps Script relay)") + log.Infof("Apps Script relay : SNI=%s -> script.google.com", cfg.GetString("front_domain", "www.google.com")) + + if ids := cfg.GetScriptIDs(); len(ids) > 0 { + if len(ids) > 1 { + log.Infof("Script IDs : %d scripts (sticky per-host)", len(ids)) + for i, id := range ids { + log.Infof(" [%d] %s", i+1, id) + } + } else { + log.Infof("Script ID : %s", ids[0]) + } + } + + if !fileExists(mitm.CACertFile) { + _ = mitm.NewManager() + } + if !a.noCertCheck { + if !cert.IsCATrusted(mitm.CACertFile, cert.DefaultCertName) { + log.Warnf("MITM CA is not trusted - attempting automatic installation...") + if cert.InstallCA(mitm.CACertFile, cert.DefaultCertName) { + log.Infof("CA certificate installed. You may need to restart your browser.") + } else { + log.Errorf("Auto-install failed. Run with --install-cert or install ca/ca.crt manually.") + } + } else { + log.Infof("MITM CA is already trusted.") + } + } + + lanSharing := cfg.GetBool("lan_sharing", false) + listenHost := cfg.GetString("listen_host", "127.0.0.1") + if lanSharing && listenHost == "127.0.0.1" { + cfg.Set("listen_host", "0.0.0.0") + listenHost = "0.0.0.0" + log.Infof("LAN sharing enabled - listening on all interfaces") + } + lanMode := lanSharing || listenHost == "0.0.0.0" || listenHost == "::" + if lanMode { + var socksPort *int + if cfg.GetBool("socks5_enabled", true) { + p := cfg.GetInt("socks5_port", 1080) + socksPort = &p + } + lan.LogLANAccess(cfg.GetInt("listen_port", 8080), socksPort) + } + + server, err := proxy.NewServer(cfg) + if err != nil { + return err + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + go func() { + sig := <-signals + fmt.Fprintf(os.Stderr, "\nReceived %v, shutting down...\n", sig) + signal.Stop(signals) + cancel() + + go func() { + time.Sleep(3 * time.Second) + fmt.Fprintf(os.Stderr, "Force exit after timeout\n") + os.Exit(1) + }() + }() + + err = server.Start(ctx) + if err != nil && !errors.Is(err, context.Canceled) { + return err + } + log.Infof("Stopped") + return nil +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func isTTY(f *os.File) bool { + info, err := f.Stat() + if err != nil { + return false + } + return (info.Mode() & os.ModeCharDevice) != 0 +} + +func envOr(name, fallback string) string { + if v := os.Getenv(name); v != "" { + return v + } + return fallback +} + +func exeDir() string { + exe, _ := os.Executable() + return filepath.Dir(exe) +} diff --git a/config.json b/config.json new file mode 100644 index 0000000..59cee70 --- /dev/null +++ b/config.json @@ -0,0 +1,23 @@ +{ + "auth_key": "m97791182k", + "chunked_download_chunk_size": 524288, + "chunked_download_max_chunks": 256, + "chunked_download_max_parallel": 8, + "chunked_download_min_size": 5242880, + "front_domain": "www.google.com", + "google_ip": "216.239.38.120", + "hosts": {}, + "lan_sharing": true, + "listen_host": "0.0.0.0", + "listen_port": 8085, + "log_level": "INFO", + "max_response_body_bytes": 209715200, + "mode": "apps_script", + "relay_timeout": 25, + "script_id": "AKfycbyhilRCuPtX9UtaLGK55m4HcoGPUQA7sB7OQdJoeLWPNU0ifZKmy6cWas7x2NHnF3_bQw", + "socks5_enabled": false, + "socks5_port": 1080, + "tcp_connect_timeout": 10, + "tls_connect_timeout": 15, + "verify_ssl": true +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..b02323d --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/denuitt1/mhr-cfw + +go 1.22 + +require ( + github.com/andybalholm/brotli v1.1.0 + github.com/klauspost/compress v1.17.9 + golang.org/x/net v0.33.0 +) + +require golang.org/x/text v0.21.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..a998fb6 --- /dev/null +++ b/go.sum @@ -0,0 +1,8 @@ +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= diff --git a/internal/cert/installer.go b/internal/cert/installer.go new file mode 100644 index 0000000..515f08e --- /dev/null +++ b/internal/cert/installer.go @@ -0,0 +1,326 @@ +package cert + +import ( + "bytes" + "crypto/sha1" + "crypto/x509" + "encoding/hex" + "encoding/pem" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + + "github.com/denuitt1/mhr-cfw/internal/logging" +) + +const DefaultCertName = "mhr-cfw" + +var log = logging.Get("Cert") + +func InstallCA(certPath, certName string) bool { + if _, err := os.Stat(certPath); err != nil { + log.Errorf("Certificate file not found: %s", certPath) + return false + } + switch runtime.GOOS { + case "windows": + ok := installWindows(certPath) + installFirefox(certPath, certName) + return ok + case "darwin": + ok := installMacOS(certPath) + installFirefox(certPath, certName) + return ok + case "linux": + ok := installLinux(certPath, certName) + installFirefox(certPath, certName) + return ok + default: + log.Errorf("Unsupported platform: %s", runtime.GOOS) + return false + } +} + +func UninstallCA(certPath, certName string) bool { + switch runtime.GOOS { + case "windows": + ok := uninstallWindows(certPath, certName) + uninstallFirefox(certName) + return ok + case "darwin": + ok := uninstallMacOS(certName) + uninstallFirefox(certName) + return ok + case "linux": + ok := uninstallLinux(certPath, certName) + uninstallFirefox(certName) + return ok + default: + log.Errorf("Unsupported platform: %s", runtime.GOOS) + return false + } +} + +func IsCATrusted(certPath, certName string) bool { + switch runtime.GOOS { + case "windows": + return isTrustedWindows(certPath) + case "darwin": + return isTrustedMacOS(certName) + case "linux": + return isTrustedLinux(certPath, certName) + default: + return false + } +} + +func run(cmd []string, check bool) ([]byte, error) { + c := exec.Command(cmd[0], cmd[1:]...) + var buf bytes.Buffer + c.Stdout = &buf + c.Stderr = &buf + err := c.Run() + if err != nil && check { + return buf.Bytes(), err + } + return buf.Bytes(), err +} + +func installWindows(certPath string) bool { + if _, err := run([]string{"certutil", "-addstore", "-user", "Root", certPath}, true); err == nil { + log.Infof("Certificate installed in Windows user Trusted Root store.") + return true + } + if _, err := run([]string{"certutil", "-addstore", "Root", certPath}, true); err == nil { + log.Infof("Certificate installed in Windows system Trusted Root store.") + return true + } + ps := "Import-Certificate -FilePath '" + certPath + "' -CertStoreLocation Cert:\\CurrentUser\\Root" + if _, err := run([]string{"powershell", "-NoProfile", "-Command", ps}, true); err == nil { + log.Infof("Certificate installed via PowerShell.") + return true + } + return false +} + +func isTrustedWindows(certPath string) bool { + out, err := run([]string{"certutil", "-user", "-store", "Root"}, true) + if err != nil { + return false + } + thumb := certThumbprint(certPath) + if thumb == "" { + return false + } + return strings.Contains(strings.ToUpper(string(out)), thumb) +} + +func uninstallWindows(certPath, certName string) bool { + thumb := certThumbprint(certPath) + target := certName + if thumb != "" { + target = thumb + } + if _, err := run([]string{"certutil", "-delstore", "-user", "Root", target}, true); err == nil { + log.Infof("Certificate removed from Windows user Trusted Root store.") + return true + } + if _, err := run([]string{"certutil", "-delstore", "Root", target}, true); err == nil { + log.Infof("Certificate removed from Windows system Trusted Root store.") + return true + } + return false +} + +func installMacOS(certPath string) bool { + login := filepath.Join(os.Getenv("HOME"), "Library/Keychains/login.keychain-db") + if _, err := run([]string{"security", "add-trusted-cert", "-d", "-r", "trustRoot", "-k", login, certPath}, true); err == nil { + log.Infof("Certificate installed in macOS login keychain.") + return true + } + if _, err := run([]string{"sudo", "security", "add-trusted-cert", "-d", "-r", "trustRoot", "-k", "/Library/Keychains/System.keychain", certPath}, true); err == nil { + log.Infof("Certificate installed in macOS system keychain.") + return true + } + return false +} + +func isTrustedMacOS(certName string) bool { + out, err := run([]string{"security", "find-certificate", "-a", "-c", certName}, true) + return err == nil && len(bytes.TrimSpace(out)) > 0 +} + +func uninstallMacOS(certName string) bool { + login := filepath.Join(os.Getenv("HOME"), "Library/Keychains/login.keychain-db") + if _, err := run([]string{"security", "delete-certificate", "-c", certName, login}, true); err == nil { + log.Infof("Certificate removed from macOS login keychain.") + return true + } + if _, err := run([]string{"sudo", "security", "delete-certificate", "-c", certName, "/Library/Keychains/System.keychain"}, true); err == nil { + log.Infof("Certificate removed from macOS system keychain.") + return true + } + return false +} + +func installLinux(certPath, certName string) bool { + distro := detectLinuxDistro() + log.Infof("Detected Linux distro family: %s", distro) + + switch distro { + case "debian": + dest := "/usr/local/share/ca-certificates/" + strings.ReplaceAll(certName, " ", "_") + ".crt" + if _, err := run([]string{"cp", certPath, dest}, true); err == nil { + _, _ = run([]string{"update-ca-certificates"}, true) + log.Infof("Certificate installed via update-ca-certificates.") + return true + } + case "rhel": + dest := "/etc/pki/ca-trust/source/anchors/" + strings.ReplaceAll(certName, " ", "_") + ".crt" + if _, err := run([]string{"cp", certPath, dest}, true); err == nil { + _, _ = run([]string{"update-ca-trust", "extract"}, true) + log.Infof("Certificate installed via update-ca-trust.") + return true + } + case "arch": + dest := "/etc/ca-certificates/trust-source/anchors/" + strings.ReplaceAll(certName, " ", "_") + ".crt" + if _, err := run([]string{"cp", certPath, dest}, true); err == nil { + _, _ = run([]string{"trust", "extract-compat"}, true) + log.Infof("Certificate installed via trust extract-compat.") + return true + } + } + log.Warnf("Unknown Linux distro. Manually install %s as a trusted root CA.", certPath) + return false +} + +func isTrustedLinux(certPath, certName string) bool { + target := strings.ReplaceAll(certName, " ", "_") + ".crt" + paths := []string{ + "/usr/local/share/ca-certificates/" + target, + "/etc/pki/ca-trust/source/anchors/" + target, + "/etc/ca-certificates/trust-source/anchors/" + target, + } + for _, p := range paths { + if _, err := os.Stat(p); err == nil { + return true + } + } + return false +} + +func uninstallLinux(certPath, certName string) bool { + distro := detectLinuxDistro() + log.Infof("Detected Linux distro family: %s", distro) + + switch distro { + case "debian": + dest := "/usr/local/share/ca-certificates/" + strings.ReplaceAll(certName, " ", "_") + ".crt" + _ = os.Remove(dest) + _, _ = run([]string{"update-ca-certificates"}, true) + log.Infof("Certificate removed via update-ca-certificates.") + return true + case "rhel": + dest := "/etc/pki/ca-trust/source/anchors/" + strings.ReplaceAll(certName, " ", "_") + ".crt" + _ = os.Remove(dest) + _, _ = run([]string{"update-ca-trust", "extract"}, true) + log.Infof("Certificate removed via update-ca-trust.") + return true + case "arch": + dest := "/etc/ca-certificates/trust-source/anchors/" + strings.ReplaceAll(certName, " ", "_") + ".crt" + _ = os.Remove(dest) + _, _ = run([]string{"trust", "extract-compat"}, true) + log.Infof("Certificate removed via trust extract-compat.") + return true + } + log.Warnf("Unknown Linux distro. Manually remove %s from trusted CAs.", certName) + return false +} + +func detectLinuxDistro() string { + if fileExists("/etc/debian_version") || fileExists("/etc/ubuntu") { + return "debian" + } + if fileExists("/etc/redhat-release") || fileExists("/etc/fedora-release") { + return "rhel" + } + if fileExists("/etc/arch-release") { + return "arch" + } + return "unknown" +} + +func installFirefox(certPath, certName string) { + if _, err := exec.LookPath("certutil"); err != nil { + return + } + profiles := firefoxProfiles() + for _, profile := range profiles { + db := "sql:" + profile + if !fileExists(filepath.Join(profile, "cert9.db")) { + db = "dbm:" + profile + } + _, _ = run([]string{"certutil", "-D", "-n", certName, "-d", db}, false) + _, _ = run([]string{"certutil", "-A", "-n", certName, "-t", "CT,,", "-i", certPath, "-d", db}, true) + } +} + +func uninstallFirefox(certName string) { + if _, err := exec.LookPath("certutil"); err != nil { + return + } + profiles := firefoxProfiles() + for _, profile := range profiles { + db := "sql:" + profile + if !fileExists(filepath.Join(profile, "cert9.db")) { + db = "dbm:" + profile + } + _, _ = run([]string{"certutil", "-D", "-n", certName, "-d", db}, false) + } +} + +func firefoxProfiles() []string { + var out []string + switch runtime.GOOS { + case "windows": + appdata := os.Getenv("APPDATA") + if appdata != "" { + out = append(out, glob(filepath.Join(appdata, "Mozilla", "Firefox", "Profiles", "*"))...) + } + case "darwin": + out = append(out, glob(filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "Firefox", "Profiles", "*"))...) + default: + out = append(out, glob(filepath.Join(os.Getenv("HOME"), ".mozilla", "firefox", "*.default*"))...) + out = append(out, glob(filepath.Join(os.Getenv("HOME"), ".mozilla", "firefox", "*.release*"))...) + } + return out +} + +func glob(pattern string) []string { + m, _ := filepath.Glob(pattern) + return m +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func certThumbprint(certPath string) string { + raw, err := os.ReadFile(certPath) + if err != nil { + return "" + } + block, _ := pem.Decode(raw) + if block == nil { + return "" + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return "" + } + sum := sha1.Sum(cert.Raw) + return strings.ToUpper(hex.EncodeToString(sum[:])) +} diff --git a/internal/codec/codec.go b/internal/codec/codec.go new file mode 100644 index 0000000..ee4e1c8 --- /dev/null +++ b/internal/codec/codec.go @@ -0,0 +1,93 @@ +package codec + +import ( + "bytes" + "compress/gzip" + "compress/zlib" + "io" + "strings" + + "github.com/andybalholm/brotli" + "github.com/klauspost/compress/zstd" +) + +var ( + hasBrotli = true + hasZstd = true +) + +func SupportedEncodings() string { + codecs := []string{"gzip", "deflate"} + if hasBrotli { + codecs = append(codecs, "br") + } + if hasZstd { + codecs = append(codecs, "zstd") + } + return strings.Join(codecs, ", ") +} + +func Decode(body []byte, encoding string) []byte { + if len(body) == 0 { + return body + } + enc := strings.TrimSpace(strings.ToLower(encoding)) + if enc == "" || enc == "identity" { + return body + } + if strings.Contains(enc, ",") { + parts := strings.Split(enc, ",") + for i := len(parts) - 1; i >= 0; i-- { + body = Decode(body, strings.TrimSpace(parts[i])) + } + return body + } + switch enc { + case "gzip": + r, err := gzip.NewReader(bytes.NewReader(body)) + if err != nil { + return body + } + defer r.Close() + out, err := io.ReadAll(r) + if err != nil { + return body + } + return out + case "deflate": + r, err := zlib.NewReader(bytes.NewReader(body)) + if err == nil { + defer r.Close() + out, err := io.ReadAll(r) + if err == nil { + return out + } + } + return body + case "br": + if !hasBrotli { + return body + } + r := brotli.NewReader(bytes.NewReader(body)) + out, err := io.ReadAll(r) + if err != nil { + return body + } + return out + case "zstd": + if !hasZstd { + return body + } + r, err := zstd.NewReader(bytes.NewReader(body)) + if err != nil { + return body + } + defer r.Close() + out, err := io.ReadAll(r) + if err != nil { + return body + } + return out + } + return body +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..6d8f210 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,155 @@ +package config + +import ( + "encoding/json" + "errors" + "os" + "strconv" + "strings" +) + +type Config map[string]any + +func Load(path string) (Config, error) { + raw, err := os.ReadFile(path) + if err != nil { + return nil, err + } + var cfg map[string]any + if err := json.Unmarshal(raw, &cfg); err != nil { + return nil, err + } + return Config(cfg), nil +} + +func (c Config) Set(key string, value any) { + c[key] = value +} + +func (c Config) GetString(key, def string) string { + if v, ok := c[key]; ok { + switch t := v.(type) { + case string: + return t + case []byte: + return string(t) + case float64: + return strconv.FormatInt(int64(t), 10) + case int: + return strconv.Itoa(t) + case bool: + if t { + return "true" + } + return "false" + } + } + return def +} + +func (c Config) GetInt(key string, def int) int { + if v, ok := c[key]; ok { + switch t := v.(type) { + case float64: + return int(t) + case int: + return t + case string: + if i, err := strconv.Atoi(strings.TrimSpace(t)); err == nil { + return i + } + } + } + return def +} + +func (c Config) GetBool(key string, def bool) bool { + if v, ok := c[key]; ok { + switch t := v.(type) { + case bool: + return t + case string: + s := strings.TrimSpace(strings.ToLower(t)) + return s == "1" || s == "true" || s == "yes" || s == "y" + case float64: + return t != 0 + } + } + return def +} + +func (c Config) GetStringSlice(key string) []string { + v, ok := c[key] + if !ok { + return nil + } + switch t := v.(type) { + case []string: + return t + case []any: + out := make([]string, 0, len(t)) + for _, item := range t { + if s, ok := item.(string); ok && strings.TrimSpace(s) != "" { + out = append(out, strings.TrimSpace(s)) + } + } + return out + case string: + s := strings.TrimSpace(t) + if s == "" { + return nil + } + return []string{s} + } + return nil +} + +func (c Config) GetStringMap(key string) map[string]string { + out := map[string]string{} + v, ok := c[key] + if !ok { + return out + } + switch t := v.(type) { + case map[string]string: + return t + case map[string]any: + for k, v := range t { + if s, ok := v.(string); ok { + out[strings.ToLower(strings.TrimSpace(k))] = strings.TrimSpace(s) + } + } + } + return out +} + +func (c Config) GetScriptIDs() []string { + ids := c.GetStringSlice("script_ids") + if len(ids) > 0 { + return ids + } + if s := c.GetString("script_id", ""); s != "" { + return []string{s} + } + return nil +} + +func (c Config) GetScriptID() string { + ids := c.GetScriptIDs() + if len(ids) > 0 { + return ids[0] + } + return "" +} + +func ToInt(v string, def int) int { + i, err := strconv.Atoi(strings.TrimSpace(v)) + if err != nil { + return def + } + return i +} + +func ErrMissing(key string) error { + return errors.New("missing required config key: " + key) +} diff --git a/internal/constants/constants.go b/internal/constants/constants.go new file mode 100644 index 0000000..d826eb2 --- /dev/null +++ b/internal/constants/constants.go @@ -0,0 +1,191 @@ +package constants + +const Version = "1.1.0" + +const ( + MaxRequestBodyBytes = 100 * 1024 * 1024 + MaxResponseBodyBytes = 200 * 1024 * 1024 + MaxHeaderBytes = 64 * 1024 +) + +const ( + ClientIdleTimeout = 120 + RelayTimeout = 25 + TLSConnectTimeout = 15 + TCPConnectTimeout = 10 +) + +const ( + GoogleScannerTimeout = 4 + GoogleScannerConcurrency = 8 +) + +var CandidateIPs = []string{ + "216.239.32.120", + "216.239.34.120", + "216.239.36.120", + "216.239.38.120", + "142.250.80.142", + "142.250.80.138", + "142.250.179.110", + "142.250.185.110", + "142.250.184.206", + "142.250.190.238", + "142.250.191.78", + "172.217.1.206", + "172.217.14.206", + "172.217.16.142", + "172.217.22.174", + "172.217.164.110", + "172.217.168.206", + "172.217.169.206", + "34.107.221.82", + "142.251.32.110", + "142.251.33.110", + "142.251.46.206", + "142.251.46.238", + "142.250.80.170", + "142.250.72.206", + "142.250.64.206", + "142.250.72.110", +} + +const ( + CacheMaxMB = 50 + CacheTTLStaticLong = 3600 + CacheTTLStaticMed = 1800 + CacheTTLMax = 86400 +) + +const ( + PoolMax = 50 + PoolMinIdle = 15 + ConnTTL = 45.0 + SemaphoreMax = 50 + WarmPoolCount = 30 +) + +const ( + BatchWindowMicro = 0.005 + BatchWindowMacro = 0.050 + BatchMax = 50 +) + +const ( + ScriptBlacklistTTL = 600.0 +) + +var FrontSNIPoolGoogle = []string{ + "www.google.com", + "mail.google.com", + "accounts.google.com", +} + +const ( + StatsLogInterval = 300.0 + StatsLogTopN = 10 +) + +var GoogleDirectExactExclude = map[string]struct{}{ + "gemini.google.com": {}, + "aistudio.google.com": {}, + "notebooklm.google.com": {}, + "labs.google.com": {}, + "meet.google.com": {}, + "accounts.google.com": {}, + "ogs.google.com": {}, + "mail.google.com": {}, + "calendar.google.com": {}, + "drive.google.com": {}, + "docs.google.com": {}, + "chat.google.com": {}, + "photos.google.com": {}, + "maps.google.com": {}, + "myaccount.google.com": {}, + "contacts.google.com": {}, + "classroom.google.com": {}, + "keep.google.com": {}, + "play.google.com": {}, + "translate.google.com": {}, + "assistant.google.com": {}, + "lens.google.com": {}, +} + +var GoogleDirectSuffixExclude = []string{ + ".meet.google.com", +} + +var GoogleDirectAllowExact = map[string]struct{}{ + "www.google.com": {}, + "google.com": {}, + "safebrowsing.google.com": {}, +} + +var GoogleDirectAllowSuffixes = []string{} + +var GoogleOwnedSuffixes = []string{ + ".google.com", ".google.co", + ".googleapis.com", ".gstatic.com", + ".googleusercontent.com", +} +var GoogleOwnedExact = map[string]struct{}{ + "google.com": {}, + "gstatic.com": {}, + "googleapis.com": {}, +} + +var SNIRewriteSuffixes = []string{ + "youtube.com", + "youtu.be", + "youtube-nocookie.com", + "ytimg.com", + "ggpht.com", + "gvt1.com", + "gvt2.com", + "doubleclick.net", + "googlesyndication.com", + "googleadservices.com", + "google-analytics.com", + "googletagmanager.com", + "googletagservices.com", + "fonts.googleapis.com", + "script.google.com", +} + +var TraceHostSuffixes = []string{ + "chatgpt.com", + "openai.com", + "gemini.google.com", + "google.com", + "cloudflare.com", + "challenges.cloudflare.com", + "turnstile", +} + +var StaticExts = []string{ + ".css", ".js", ".mjs", ".woff", ".woff2", ".ttf", ".eot", + ".png", ".jpg", ".jpeg", ".gif", ".webp", ".svg", ".ico", + ".mp3", ".mp4", ".webm", ".wasm", ".avif", +} + +var LargeFileExts = map[string]struct{}{ + ".bin": {}, + ".zip": {}, ".tar": {}, ".gz": {}, ".bz2": {}, ".xz": {}, ".7z": {}, ".rar": {}, + ".exe": {}, ".msi": {}, ".dmg": {}, ".deb": {}, ".rpm": {}, ".apk": {}, + ".iso": {}, ".img": {}, + ".mp4": {}, ".mkv": {}, ".avi": {}, ".mov": {}, ".webm": {}, + ".mp3": {}, ".flac": {}, ".wav": {}, ".aac": {}, + ".pdf": {}, ".doc": {}, ".docx": {}, ".ppt": {}, ".pptx": {}, + ".wasm": {}, +} + +var StatefulHeaderNames = []string{ + "cookie", "authorization", "proxy-authorization", + "origin", "referer", "if-none-match", "if-modified-since", + "cache-control", "pragma", +} + +var UncacheableHeaderNames = []string{ + "cookie", "authorization", "proxy-authorization", "range", + "if-none-match", "if-modified-since", "cache-control", "pragma", +} diff --git a/internal/fronter/fronter.go b/internal/fronter/fronter.go new file mode 100644 index 0000000..048e357 --- /dev/null +++ b/internal/fronter/fronter.go @@ -0,0 +1,771 @@ +package fronter + +import ( + "bufio" + "bytes" + "context" + "crypto/sha1" + "crypto/tls" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/url" + "regexp" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/denuitt1/mhr-cfw/internal/codec" + "github.com/denuitt1/mhr-cfw/internal/config" + "github.com/denuitt1/mhr-cfw/internal/constants" + "github.com/denuitt1/mhr-cfw/internal/h2" + "github.com/denuitt1/mhr-cfw/internal/logging" +) + +var log = logging.Get("Fronter") + +type HostStat struct { + Requests int + CacheHits int + Bytes int + TotalLatencyNs int64 + Errors int +} + +type DomainFronter struct { + connectHost string + sniHost string + sniHosts []string + sniIdx int + httpHost string + scriptIDs []string + scriptIdx int + devAvail bool + + parallelRelay int + sidBlacklist map[string]time.Time + blacklistTTL time.Duration + + perSite map[string]*HostStat + + authKey string + verifySSL bool + relayTO time.Duration + tlsTO time.Duration + maxResp int + + h2 *h2.Transport + + poolMu sync.Mutex + pool []pooledConn + + batchMu sync.Mutex + batchPending []batchItem + batchTimer *time.Timer + + coalesceMu sync.Mutex + coalesce map[string][]chan []byte + + statsStop chan struct{} +} + +type pooledConn struct { + conn net.Conn + created time.Time +} + +type batchItem struct { + payload map[string]any + respCh chan []byte +} + +func New(cfg config.Config) *DomainFronter { + frontDomain := cfg.GetString("front_domain", "www.google.com") + fronts := buildSNIPool(frontDomain, cfg.GetStringSlice("front_domains")) + ids := cfg.GetScriptIDs() + if len(ids) == 0 { + ids = []string{cfg.GetString("script_id", "")} + } + parallel := cfg.GetInt("parallel_relay", 1) + if parallel < 1 { + parallel = 1 + } + if parallel > len(ids) { + parallel = len(ids) + } + + f := &DomainFronter{ + connectHost: cfg.GetString("google_ip", "216.239.38.120"), + sniHost: frontDomain, + sniHosts: fronts, + httpHost: "script.google.com", + scriptIDs: ids, + sidBlacklist: map[string]time.Time{}, + blacklistTTL: time.Duration(constants.ScriptBlacklistTTL * float64(time.Second)), + perSite: map[string]*HostStat{}, + authKey: cfg.GetString("auth_key", ""), + verifySSL: cfg.GetBool("verify_ssl", true), + relayTO: time.Duration(cfg.GetInt("relay_timeout", constants.RelayTimeout)) * time.Second, + tlsTO: time.Duration(cfg.GetInt("tls_connect_timeout", constants.TLSConnectTimeout)) * time.Second, + maxResp: cfg.GetInt("max_response_body_bytes", constants.MaxResponseBodyBytes), + parallelRelay: parallel, + coalesce: map[string][]chan []byte{}, + statsStop: make(chan struct{}), + } + + if len(fronts) > 1 { + log.Infof("SNI rotation pool (%d): %s", len(fronts), strings.Join(fronts, ", ")) + } + if parallel > 1 { + log.Infof("Fan-out relay: %d parallel Apps Script instances per request", parallel) + } + log.Infof("Response codecs: %s", codec.SupportedEncodings()) + + f.h2 = h2.New(f.connectHost, f.sniHosts, f.verifySSL) + go f.statsLoop() + return f +} + +func buildSNIPool(frontDomain string, overrides []string) []string { + if len(overrides) > 0 { + seen := map[string]bool{} + out := []string{} + for _, item := range overrides { + host := strings.ToLower(strings.TrimSuffix(strings.TrimSpace(item), ".")) + if host != "" && !seen[host] { + seen[host] = true + out = append(out, host) + } + } + if len(out) > 0 { + return out + } + } + fd := strings.ToLower(strings.TrimSuffix(frontDomain, ".")) + if strings.HasSuffix(fd, ".google.com") || fd == "google.com" { + pool := []string{fd} + for _, h := range constants.FrontSNIPoolGoogle { + if h != fd { + pool = append(pool, h) + } + } + return pool + } + if fd == "" { + return []string{"www.google.com"} + } + return []string{fd} +} + +func (f *DomainFronter) Close() error { + close(f.statsStop) + if f.h2 != nil { + _ = f.h2.Close() + } + f.poolMu.Lock() + for _, pc := range f.pool { + _ = pc.conn.Close() + } + f.pool = nil + f.poolMu.Unlock() + return nil +} + +func (f *DomainFronter) Relay(method, urlStr string, headers map[string]string, body []byte) []byte { + payload := f.buildPayload(method, urlStr, headers, body) + start := time.Now() + err := false + var raw []byte + defer func() { + f.recordSite(urlStr, len(raw), time.Since(start), err) + }() + + if f.isStatefulRequest(method, urlStr, headers, body) { + resp, e := f.relaySingle(payload) + if e != nil { + err = true + return f.errorResponse(502, e.Error()) + } + return resp + } + + key := f.coalesceKey(urlStr, headers) + if strings.ToUpper(method) == "GET" && len(body) == 0 { + if v := headerValue(headers, "range"); v == "" { + if resp, ok := f.tryCoalesce(key, payload); ok { + return resp + } + } + } + + resp, e := f.batchSubmit(payload) + if e != nil { + err = true + return f.errorResponse(502, e.Error()) + } + return resp +} + +func (f *DomainFronter) tryCoalesce(key string, payload map[string]any) ([]byte, bool) { + f.coalesceMu.Lock() + if waiters, ok := f.coalesce[key]; ok { + ch := make(chan []byte, 1) + f.coalesce[key] = append(waiters, ch) + f.coalesceMu.Unlock() + resp := <-ch + return resp, true + } + f.coalesce[key] = []chan []byte{} + f.coalesceMu.Unlock() + + resp, err := f.batchSubmit(payload) + if err != nil { + resp = f.errorResponse(502, err.Error()) + } + + f.coalesceMu.Lock() + waiters := f.coalesce[key] + delete(f.coalesce, key) + f.coalesceMu.Unlock() + for _, ch := range waiters { + ch <- resp + } + return resp, true +} + +func (f *DomainFronter) batchSubmit(payload map[string]any) ([]byte, error) { + respCh := make(chan []byte, 1) + item := batchItem{payload: payload, respCh: respCh} + + f.batchMu.Lock() + f.batchPending = append(f.batchPending, item) + if len(f.batchPending) >= constants.BatchMax { + pending := f.batchPending + f.batchPending = nil + if f.batchTimer != nil { + f.batchTimer.Stop() + f.batchTimer = nil + } + f.batchMu.Unlock() + go f.flushBatch(pending) + return <-respCh, nil + } + if f.batchTimer == nil { + f.batchTimer = time.AfterFunc(time.Duration(constants.BatchWindowMicro*float64(time.Second)), func() { + f.batchMu.Lock() + pending := f.batchPending + f.batchPending = nil + f.batchTimer = nil + f.batchMu.Unlock() + if len(pending) > 0 { + f.flushBatch(pending) + } + }) + } + f.batchMu.Unlock() + return <-respCh, nil +} + +func (f *DomainFronter) flushBatch(batch []batchItem) { + if len(batch) == 1 { + resp, err := f.relaySingle(batch[0].payload) + if err != nil { + resp = f.errorResponse(502, err.Error()) + } + batch[0].respCh <- resp + return + } + results, err := f.relayBatch(batch) + if err != nil { + for _, item := range batch { + item.respCh <- f.errorResponse(502, err.Error()) + } + return + } + for i, item := range batch { + item.respCh <- results[i] + } +} + +func (f *DomainFronter) relaySingle(payload map[string]any) ([]byte, error) { + full := map[string]any{} + for k, v := range payload { + full[k] = v + } + full["k"] = f.authKey + jsonBody, _ := json.Marshal(full) + path := f.execPath(payload["u"]) + + _, _, body, err := f.h2.Request(context.Background(), "POST", path, f.httpHost, map[string]string{"content-type": "application/json"}, jsonBody, f.relayTO) + if err == nil { + return f.parseRelayResponse(body), nil + } + + resp, err := f.relayHTTP1(path, jsonBody) + if err != nil { + return nil, err + } + return f.parseRelayResponse(resp), nil +} + +func (f *DomainFronter) relayBatch(batch []batchItem) ([][]byte, error) { + payloads := []map[string]any{} + for _, item := range batch { + payloads = append(payloads, item.payload) + } + full := map[string]any{ + "k": f.authKey, + "q": payloads, + } + jsonBody, _ := json.Marshal(full) + path := f.execPath(payloads[0]["u"]) + + _, _, body, err := f.h2.Request(context.Background(), "POST", path, f.httpHost, map[string]string{"content-type": "application/json"}, jsonBody, 30*time.Second) + if err == nil { + return f.parseBatchBody(body, len(batch)) + } + resp, err := f.relayHTTP1(path, jsonBody) + if err != nil { + return nil, err + } + return f.parseBatchBody(resp, len(batch)) +} + +func (f *DomainFronter) relayHTTP1(path string, body []byte) ([]byte, error) { + conn, err := f.acquire() + if err != nil { + return nil, err + } + defer f.release(conn) + + req := fmt.Sprintf("POST %s HTTP/1.1\r\nHost: %s\r\nContent-Type: application/json\r\nContent-Length: %d\r\nAccept-Encoding: gzip\r\nConnection: keep-alive\r\n\r\n", path, f.httpHost, len(body)) + if _, err := conn.Write([]byte(req)); err != nil { + return nil, err + } + if _, err := conn.Write(body); err != nil { + return nil, err + } + + status, headers, respBody, err := readHTTPResponse(conn, f.maxResp) + if err != nil { + return nil, err + } + + if status >= 300 && status < 400 { + loc := headers["location"] + if loc != "" { + parsed, _ := url.Parse(loc) + rpath := parsed.Path + if parsed.RawQuery != "" { + rpath += "?" + parsed.RawQuery + } + return f.relayHTTP1(rpath, body) + } + } + return respBody, nil +} + +func readHTTPResponse(conn net.Conn, maxBody int) (int, map[string]string, []byte, error) { + reader := bufio.NewReader(conn) + statusLine, err := reader.ReadString('\n') + if err != nil { + return 0, nil, nil, err + } + status := 0 + if m := regexp.MustCompile(`\d{3}`).FindString(statusLine); m != "" { + status, _ = strconv.Atoi(m) + } + headers := map[string]string{} + for { + line, err := reader.ReadString('\n') + if err != nil { + return status, headers, nil, err + } + line = strings.TrimRight(line, "\r\n") + if line == "" { + break + } + parts := strings.SplitN(line, ":", 2) + if len(parts) == 2 { + headers[strings.ToLower(strings.TrimSpace(parts[0]))] = strings.TrimSpace(parts[1]) + } + } + + cl := 0 + if v := headers["content-length"]; v != "" { + cl, _ = strconv.Atoi(v) + } + body := []byte{} + if cl > 0 { + if cl > maxBody { + return status, headers, nil, errors.New("response exceeds cap") + } + buf := make([]byte, cl) + _, err = io.ReadFull(reader, buf) + if err != nil { + return status, headers, nil, err + } + body = buf + } else { + buf, _ := io.ReadAll(reader) + body = buf + } + if enc := headers["content-encoding"]; enc != "" { + body = codec.Decode(body, enc) + } + return status, headers, body, nil +} + +func (f *DomainFronter) acquire() (net.Conn, error) { + f.poolMu.Lock() + for len(f.pool) > 0 { + pc := f.pool[len(f.pool)-1] + f.pool = f.pool[:len(f.pool)-1] + if time.Since(pc.created) < time.Duration(constants.ConnTTL*float64(time.Second)) { + f.poolMu.Unlock() + return pc.conn, nil + } + _ = pc.conn.Close() + } + f.poolMu.Unlock() + + dialer := &net.Dialer{Timeout: f.tlsTO} + conn, err := dialer.Dial("tcp", net.JoinHostPort(f.connectHost, "443")) + if err != nil { + return nil, err + } + if tcp, ok := conn.(*net.TCPConn); ok { + _ = tcp.SetNoDelay(true) + } + tlsConn := tls.Client(conn, &tls.Config{ServerName: f.nextSNI(), InsecureSkipVerify: !f.verifySSL}) + if err := tlsConn.Handshake(); err != nil { + _ = conn.Close() + return nil, err + } + return tlsConn, nil +} + +func (f *DomainFronter) release(conn net.Conn) { + f.poolMu.Lock() + defer f.poolMu.Unlock() + if len(f.pool) >= constants.PoolMax { + _ = conn.Close() + return + } + f.pool = append(f.pool, pooledConn{conn: conn, created: time.Now()}) +} + +func (f *DomainFronter) nextSNI() string { + sni := f.sniHosts[f.sniIdx%len(f.sniHosts)] + f.sniIdx++ + return sni +} + +func (f *DomainFronter) execPath(urlOrHost any) string { + sid := f.scriptIDForKey(hostKey(fmt.Sprint(urlOrHost))) + if f.devAvail { + return "/macros/s/" + sid + "/dev" + } + return "/macros/s/" + sid + "/exec" +} + +func hostKey(urlOrHost string) string { + if urlOrHost == "" { + return "" + } + if strings.Contains(urlOrHost, "://") { + parsed, err := url.Parse(urlOrHost) + if err == nil { + return strings.ToLower(strings.TrimSuffix(parsed.Hostname(), ".")) + } + } + return strings.ToLower(strings.TrimSuffix(urlOrHost, ".")) +} + +func (f *DomainFronter) scriptIDForKey(key string) string { + if len(f.scriptIDs) == 1 { + return f.scriptIDs[0] + } + if key == "" { + f.scriptIdx = (f.scriptIdx + 1) % len(f.scriptIDs) + return f.scriptIDs[f.scriptIdx] + } + h := sha1.Sum([]byte(key)) + idx := int(h[0]) % len(f.scriptIDs) + return f.scriptIDs[idx] +} + +func (f *DomainFronter) buildPayload(method, urlStr string, headers map[string]string, body []byte) map[string]any { + p := map[string]any{ + "m": method, + "u": urlStr, + "r": false, + } + if headers != nil { + p["h"] = headers + } + if len(body) > 0 { + p["b"] = base64.StdEncoding.EncodeToString(body) + if ct := headerValue(headers, "content-type"); ct != "" { + p["ct"] = ct + } + } + return p +} + +func (f *DomainFronter) parseRelayResponse(body []byte) []byte { + text := strings.TrimSpace(string(body)) + if text == "" { + return f.errorResponse(502, "Empty response from relay") + } + var data map[string]any + if err := json.Unmarshal([]byte(text), &data); err != nil { + m := regexp.MustCompile(`\{.*\}`).FindString(text) + if m == "" { + return f.errorResponse(502, "No JSON: "+truncate(text, 200)) + } + if err := json.Unmarshal([]byte(m), &data); err != nil { + return f.errorResponse(502, "Bad JSON: "+truncate(text, 200)) + } + } + return f.parseRelayJSON(data) +} + +func (f *DomainFronter) errorResponse(status int, message string) []byte { + body := fmt.Sprintf("
%s
", status, message) + resp := fmt.Sprintf("HTTP/1.1 %d Error\r\nContent-Type: text/html\r\nContent-Length: %d\r\n\r\n%s", status, len(body), body) + return []byte(resp) +} + +func (f *DomainFronter) parseRelayJSON(data map[string]any) []byte { + if e, ok := data["e"]; ok { + return f.errorResponse(502, fmt.Sprintf("Relay error: %v", e)) + } + status := intVal(data["s"], 200) + headers := map[string]any{} + if h, ok := data["h"].(map[string]any); ok { + headers = h + } + bodyRaw := "" + if b, ok := data["b"].(string); ok { + bodyRaw = b + } + body, _ := base64.StdEncoding.DecodeString(bodyRaw) + if len(body) > f.maxResp { + return f.errorResponse(502, "Relay response exceeds cap") + } + statusText := "OK" + switch status { + case 206: + statusText = "Partial Content" + case 301: + statusText = "Moved" + case 302: + statusText = "Found" + case 304: + statusText = "Not Modified" + case 400: + statusText = "Bad Request" + case 403: + statusText = "Forbidden" + case 404: + statusText = "Not Found" + case 500: + statusText = "Internal Server Error" + } + + buf := bytes.NewBufferString(fmt.Sprintf("HTTP/1.1 %d %s\r\n", status, statusText)) + skip := map[string]bool{ + "transfer-encoding": true, + "connection": true, + "keep-alive": true, + "content-length": true, + "content-encoding": true, + } + for k, v := range headers { + lk := strings.ToLower(k) + if skip[lk] { + continue + } + switch val := v.(type) { + case []any: + for _, item := range val { + buf.WriteString(fmt.Sprintf("%s: %v\r\n", k, item)) + } + default: + buf.WriteString(fmt.Sprintf("%s: %v\r\n", k, val)) + } + } + buf.WriteString(fmt.Sprintf("Content-Length: %d\r\n\r\n", len(body))) + buf.Write(body) + return buf.Bytes() +} + +func (f *DomainFronter) parseBatchBody(body []byte, expected int) ([][]byte, error) { + text := strings.TrimSpace(string(body)) + var data map[string]any + if err := json.Unmarshal([]byte(text), &data); err != nil { + return nil, err + } + if e, ok := data["e"]; ok { + return nil, fmt.Errorf("Batch error: %v", e) + } + arr, ok := data["q"].([]any) + if !ok || len(arr) != expected { + return nil, errors.New("batch size mismatch") + } + results := make([][]byte, 0, len(arr)) + for _, item := range arr { + if obj, ok := item.(map[string]any); ok { + results = append(results, f.parseRelayJSON(obj)) + } + } + return results, nil +} + +func (f *DomainFronter) isStatefulRequest(method, urlStr string, headers map[string]string, body []byte) bool { + method = strings.ToUpper(method) + if method != "GET" && method != "HEAD" { + return true + } + if len(body) > 0 { + return true + } + for _, name := range constants.StatefulHeaderNames { + if headerValue(headers, name) != "" { + return true + } + } + accept := strings.ToLower(headerValue(headers, "accept")) + if strings.Contains(accept, "text/html") || strings.Contains(accept, "application/json") { + return true + } + fetchMode := strings.ToLower(headerValue(headers, "sec-fetch-mode")) + if fetchMode == "navigate" || fetchMode == "cors" { + return true + } + return !isStaticAssetURL(urlStr) +} + +func isStaticAssetURL(urlStr string) bool { + parsed, err := url.Parse(urlStr) + if err != nil { + return false + } + path := strings.ToLower(parsed.Path) + for _, ext := range constants.StaticExts { + if strings.HasSuffix(path, ext) { + return true + } + } + return false +} + +func (f *DomainFronter) coalesceKey(urlStr string, headers map[string]string) string { + key := []string{urlStr} + if headers != nil { + for _, name := range []string{"accept", "accept-language", "user-agent", "sec-fetch-dest", "sec-fetch-mode", "sec-fetch-site"} { + if v := headerValue(headers, name); v != "" { + key = append(key, name+"="+v) + } + } + } + return strings.Join(key, "\n") +} + +func (f *DomainFronter) recordSite(urlStr string, bytes int, latency time.Duration, errored bool) { + host := hostKey(urlStr) + if host == "" { + return + } + stat, ok := f.perSite[host] + if !ok { + stat = &HostStat{} + f.perSite[host] = stat + } + stat.Requests++ + stat.Bytes += bytes + stat.TotalLatencyNs += latency.Nanoseconds() + if errored { + stat.Errors++ + } +} + +func (f *DomainFronter) statsLoop() { + ticker := time.NewTicker(time.Duration(constants.StatsLogInterval) * time.Second) + defer ticker.Stop() + for { + select { + case <-f.statsStop: + return + case <-ticker.C: + f.logStats() + } + } +} + +func (f *DomainFronter) logStats() { + if len(f.perSite) == 0 { + return + } + type statEntry struct { + host string + stat *HostStat + } + entries := make([]statEntry, 0, len(f.perSite)) + for host, stat := range f.perSite { + entries = append(entries, statEntry{host: host, stat: stat}) + } + sort.Slice(entries, func(i, j int) bool { + return entries[i].stat.Bytes > entries[j].stat.Bytes + }) + count := constants.StatsLogTopN + if count > len(entries) { + count = len(entries) + } + log.Debugf("-- Per-host stats (top %d by bytes) --", count) + for i := 0; i < count; i++ { + e := entries[i] + avgLatency := time.Duration(0) + if e.stat.Requests > 0 { + avgLatency = time.Duration(e.stat.TotalLatencyNs / int64(e.stat.Requests)) + } + log.Debugf(" %s: %d reqs, %.2fMB, %s avg, %d errs", + e.host, e.stat.Requests, float64(e.stat.Bytes)/1024/1024, avgLatency, e.stat.Errors) + } +} + +func headerValue(headers map[string]string, name string) string { + for k, v := range headers { + if strings.ToLower(k) == name { + return v + } + } + return "" +} + +func truncate(s string, max int) string { + if len(s) <= max { + return s + } + return s[:max] +} + +func intVal(v any, def int) int { + switch t := v.(type) { + case float64: + return int(t) + case int: + return t + case string: + if i, err := strconv.Atoi(t); err == nil { + return i + } + } + return def +} diff --git a/internal/h2/h2_transport.go b/internal/h2/h2_transport.go new file mode 100644 index 0000000..aea4cb3 --- /dev/null +++ b/internal/h2/h2_transport.go @@ -0,0 +1,144 @@ +package h2 + +import ( + "bytes" + "context" + "crypto/tls" + "errors" + "io" + "net" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + "golang.org/x/net/http2" + + "github.com/denuitt1/mhr-cfw/internal/codec" + "github.com/denuitt1/mhr-cfw/internal/logging" +) + +var log = logging.Get("H2") + +type Transport struct { + connectHost string + verifySSL bool + sniHosts []string + sniIdx uint32 + + mu sync.Mutex + h2 *http2.Transport + client *http.Client +} + +func New(connectHost string, sniHosts []string, verifySSL bool) *Transport { + if len(sniHosts) == 0 { + sniHosts = []string{"www.google.com"} + } + return &Transport{ + connectHost: connectHost, + verifySSL: verifySSL, + sniHosts: sniHosts, + } +} + +func (t *Transport) ensure() { + t.mu.Lock() + defer t.mu.Unlock() + if t.h2 != nil && t.client != nil { + return + } + tr := &http2.Transport{ + AllowHTTP: false, + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + sni := t.nextSNI() + tlsCfg := &tls.Config{ + ServerName: sni, + InsecureSkipVerify: !t.verifySSL, + NextProtos: []string{"h2", "http/1.1"}, + } + dialer := &net.Dialer{Timeout: 15 * time.Second} + conn, err := dialer.DialContext(ctx, network, net.JoinHostPort(t.connectHost, "443")) + if err != nil { + return nil, err + } + if tcp, ok := conn.(*net.TCPConn); ok { + _ = tcp.SetNoDelay(true) + } + tlsConn := tls.Client(conn, tlsCfg) + if err := tlsConn.HandshakeContext(ctx); err != nil { + _ = conn.Close() + return nil, err + } + if tlsConn.ConnectionState().NegotiatedProtocol != "h2" { + _ = tlsConn.Close() + return nil, errors.New("h2 ALPN negotiation failed") + } + return tlsConn, nil + }, + } + client := &http.Client{Transport: tr} + t.h2 = tr + t.client = client + log.Infof("H2 transport ready -> %s", t.connectHost) +} + +func (t *Transport) Request(ctx context.Context, method, path, host string, headers map[string]string, body []byte, timeout time.Duration) (int, map[string]string, []byte, error) { + t.ensure() + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + u := &url.URL{Scheme: "https", Host: host, Path: path} + req, err := http.NewRequestWithContext(ctx, method, u.String(), bytes.NewReader(body)) + if err != nil { + return 0, nil, nil, err + } + for k, v := range headers { + req.Header.Set(k, v) + } + req.Header.Set("accept-encoding", codec.SupportedEncodings()) + req.Host = host + + ctx, cancel := context.WithTimeout(req.Context(), timeout) + defer cancel() + req = req.WithContext(ctx) + + resp, err := t.client.Do(req) + if err != nil { + return 0, nil, nil, err + } + defer resp.Body.Close() + + data, err := io.ReadAll(resp.Body) + if err != nil { + return 0, nil, nil, err + } + respHeaders := map[string]string{} + for k, v := range resp.Header { + if len(v) > 0 { + respHeaders[strings.ToLower(k)] = v[0] + } + } + if enc := respHeaders["content-encoding"]; enc != "" { + data = codec.Decode(data, enc) + } + return resp.StatusCode, respHeaders, data, nil +} + +func (t *Transport) Close() error { + t.mu.Lock() + defer t.mu.Unlock() + if t.h2 != nil { + t.h2.CloseIdleConnections() + } + t.h2 = nil + t.client = nil + return nil +} + +func (t *Transport) nextSNI() string { + idx := atomic.AddUint32(&t.sniIdx, 1) + return t.sniHosts[int(idx)%len(t.sniHosts)] +} diff --git a/internal/lan/lan.go b/internal/lan/lan.go new file mode 100644 index 0000000..c6bf721 --- /dev/null +++ b/internal/lan/lan.go @@ -0,0 +1,95 @@ +package lan + +import ( + "net" + "net/netip" + "os" + "strconv" + "strings" + + "github.com/denuitt1/mhr-cfw/internal/logging" +) + +var log = logging.Get("LAN") + +func GetNetworkInterfaces() map[string][]string { + out := map[string][]string{} + seen := map[string]bool{} + + add := func(label, ip string) { + if ip == "" || seen[ip] || strings.HasPrefix(ip, "127.") { + return + } + seen[ip] = true + out[label] = append(out[label], ip) + } + + if ip := primaryIPv4(); ip != "" { + add("primary", ip) + } + + host, _ := os.Hostname() + if host != "" { + if addrs, err := net.LookupIP(host); err == nil { + for _, a := range addrs { + if a4 := a.To4(); a4 != nil { + add("host", a4.String()) + } + } + } + } + + return out +} + +func GetLANIPs(port int) []string { + ifaces := GetNetworkInterfaces() + var lan []string + seen := map[string]bool{} + for _, ips := range ifaces { + for _, ip := range ips { + addr, err := netip.ParseAddr(ip) + if err != nil { + continue + } + if addr.IsLoopback() || addr.IsUnspecified() { + continue + } + if addr.IsPrivate() || addr.IsLinkLocalUnicast() { + addrStr := ip + ":" + strconv.Itoa(port) + if !seen[addrStr] { + seen[addrStr] = true + lan = append(lan, addrStr) + } + } + } + } + return lan +} + +func LogLANAccess(port int, socksPort *int) { + lanHTTP := GetLANIPs(port) + if len(lanHTTP) > 0 { + log.Infof("LAN HTTP proxy : %s", strings.Join(lanHTTP, ", ")) + } else { + log.Warnf("No LAN IP addresses detected for HTTP proxy") + } + if socksPort != nil { + lanSocks := GetLANIPs(*socksPort) + if len(lanSocks) > 0 { + log.Infof("LAN SOCKS5 proxy : %s", strings.Join(lanSocks, ", ")) + } else { + log.Warnf("No LAN IP addresses detected for SOCKS5 proxy") + } + } +} + +func primaryIPv4() string { + conn, err := net.Dial("udp", "192.0.2.1:80") + if err != nil { + return "" + } + defer conn.Close() + local := conn.LocalAddr().(*net.UDPAddr) + return local.IP.String() +} diff --git a/internal/logging/logging.go b/internal/logging/logging.go new file mode 100644 index 0000000..0300b44 --- /dev/null +++ b/internal/logging/logging.go @@ -0,0 +1,245 @@ +package logging + +import ( + "fmt" + "io" + "os" + "runtime" + "strings" + "sync" + "time" +) + +type Level int + +const ( + Debug Level = iota + Info + Warn + Error +) + +type Logger struct { + name string +} + +var ( + mu sync.RWMutex + globalLvl = Info + colorOn = false + outWriter io.Writer = os.Stderr +) + +func Configure(level string) { + lvl := Info + switch strings.ToUpper(level) { + case "DEBUG": + lvl = Debug + case "WARNING", "WARN": + lvl = Warn + case "ERROR": + lvl = Error + default: + lvl = Info + } + mu.Lock() + globalLvl = lvl + outWriter = os.Stderr + colorOn = supportsColor(os.Stderr) + mu.Unlock() +} + +func Get(name string) *Logger { + return &Logger{name: name} +} + +func (l *Logger) Debugf(format string, args ...any) { + l.log(Debug, format, args...) +} + +func (l *Logger) Infof(format string, args ...any) { + l.log(Info, format, args...) +} + +func (l *Logger) Warnf(format string, args ...any) { + l.log(Warn, format, args...) +} + +func (l *Logger) Errorf(format string, args ...any) { + l.log(Error, format, args...) +} + +func (l *Logger) log(level Level, format string, args ...any) { + mu.RLock() + if level < globalLvl { + mu.RUnlock() + return + } + out := outWriter + useColor := colorOn + mu.RUnlock() + + now := time.Now() + ts := now.Format("15:04:05") + levelLabel := levelText(level) + line := fmt.Sprintf(format, args...) + component := l.name + if len(component) > 8 { + component = component[:8] + } + component = fmt.Sprintf("%-8s", component) + + if useColor { + ts = color("90", ts) + levelLabel = color(levelColor(level), levelLabel) + component = color(componentColor(l.name), "["+component+"]") + } else { + component = "[" + component + "]" + } + + fmt.Fprintf(out, "%s %s %s %s\n", ts, levelLabel, component, line) +} + +func levelText(level Level) string { + switch level { + case Debug: + return "DBG" + case Info: + return "INF" + case Warn: + return "WRN" + case Error: + return "ERR" + default: + return "INF" + } +} + +func levelColor(level Level) string { + switch level { + case Debug: + return "38;5;245" + case Info: + return "38;5;39" + case Warn: + return "38;5;214" + case Error: + return "38;5;203" + default: + return "38;5;39" + } +} + +func componentColor(name string) string { + switch name { + case "Main": + return "38;5;81" + case "Proxy": + return "38;5;75" + case "Fronter": + return "38;5;141" + case "H2": + return "38;5;87" + case "MITM": + return "38;5;208" + case "Cert": + return "38;5;177" + case "LAN": + return "38;5;80" + case "Scanner": + return "38;5;45" + default: + return "38;5;245" + } +} + +func color(code, text string) string { + return "\x1b[" + code + "m" + text + "\x1b[0m" +} + +func bold(s string) string { return "\x1b[1m" + s + "\x1b[0m" } +func dim(s string) string { return "\x1b[2m" + s + "\x1b[0m" } +func teal(s string) string { return "\x1b[1;38;5;45m" + s + "\x1b[0m" } +func faint(s string) string { return "\x1b[38;5;250m" + s + "\x1b[0m" } +func amber(s string) string { return "\x1b[38;5;214m" + s + "\x1b[0m" } +func violet(s string) string { return "\x1b[38;5;141m" + s + "\x1b[0m" } + +func supportsColor(stream *os.File) bool { + if os.Getenv("NO_COLOR") != "" || os.Getenv("DFT_NO_COLOR") == "1" { + return false + } + if os.Getenv("FORCE_COLOR") != "" || os.Getenv("DFT_FORCE_COLOR") != "" { + return true + } + info, err := stream.Stat() + if err != nil || (info.Mode()&os.ModeCharDevice) == 0 { + return false + } + if runtime.GOOS != "windows" { + return true + } + return enableVirtualTerminal(stream) +} + + +func enableVirtualTerminal(stream *os.File) bool { + if runtime.GOOS != "windows" { + return true + } + // PowerShell + Windows Terminal already support ANSI; conservative default + return true +} + +func PrintBanner(version string) { + title := "MHR-CFW Go Version" + subtitle := "Domain-Fronted Relay Suite" + credit := "ThisIsDara" + versionTag := "v" + version + + innerWidth := max(76, max(len(title), max(len(subtitle), len(credit)))+8) + line := strings.Repeat("═", innerWidth) + borderTop := "╔ " + line + " ╗" + borderMid := "║" + strings.Repeat(" ", innerWidth) + "║" + borderBot := "╚ " + line + " ╝" + + centerLine := func(text string) string { + pad := innerWidth - len(text) + left := pad / 2 + right := pad - left + return "║" + strings.Repeat(" ", left) + text + strings.Repeat(" ", right) + "║" + } + + if colorOn { + fmt.Fprintln(outWriter) + fmt.Fprintln(outWriter, borderTop) + fmt.Fprintln(outWriter, borderMid) + outLine := "║" + bold(teal(centerLine(title))) + "║" + fmt.Fprintln(outWriter, outLine) + outLine = "║" + faint(centerLine(subtitle)) + "║" + fmt.Fprintln(outWriter, outLine) + outLine = "║" + amber(centerLine(versionTag)) + "║" + fmt.Fprintln(outWriter, outLine) + outLine = "║" + violet(centerLine(credit)) + "║" + fmt.Fprintln(outWriter, outLine) + fmt.Fprintln(outWriter, borderMid) + fmt.Fprintln(outWriter, borderBot) + return + } + + fmt.Println() + fmt.Println(borderTop) + fmt.Println(borderMid) + fmt.Println(centerLine(title)) + fmt.Println(centerLine(subtitle)) + fmt.Println(centerLine(versionTag)) + fmt.Println(centerLine(credit)) + fmt.Println(borderMid) + fmt.Println(borderBot) +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/internal/mitm/mitm.go b/internal/mitm/mitm.go new file mode 100644 index 0000000..864aa6f --- /dev/null +++ b/internal/mitm/mitm.go @@ -0,0 +1,187 @@ +package mitm + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "os" + "path/filepath" + "sync" + "time" +) + +var ( + projectRoot = func() string { + wd, err := os.Getwd() + if err != nil { + exe, _ := os.Executable() + return filepath.Dir(exe) + } + return wd + }() + CADir = filepath.Join(projectRoot, "ca") + CAKeyFile = filepath.Join(CADir, "ca.key") + CACertFile = filepath.Join(CADir, "ca.crt") +) + +type Manager struct { + mu sync.Mutex + caKey *rsa.PrivateKey + caCert *x509.Certificate + cache map[string]*tls.Certificate +} + +func NewManager() *Manager { + m := &Manager{ + cache: map[string]*tls.Certificate{}, + } + m.ensureCA() + return m +} + +func (m *Manager) GetServerTLSConfig(domain string) (*tls.Config, error) { + cert, err := m.getCertificate(domain) + if err != nil { + return nil, err + } + return &tls.Config{ + Certificates: []tls.Certificate{*cert}, + NextProtos: []string{"http/1.1"}, + }, nil +} + +func (m *Manager) getCertificate(domain string) (*tls.Certificate, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if cert, ok := m.cache[domain]; ok { + return cert, nil + } + if m.caKey == nil || m.caCert == nil { + m.ensureCA() + } + key, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, err + } + + now := time.Now().UTC() + tmpl := &x509.Certificate{ + SerialNumber: randomSerial(), + Subject: pkix.Name{ + CommonName: domain, + }, + NotBefore: now, + NotAfter: now.AddDate(1, 0, 0), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + if ip := net.ParseIP(domain); ip != nil { + tmpl.IPAddresses = []net.IP{ip} + } else { + tmpl.DNSNames = []string{domain} + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, m.caCert, &key.PublicKey, m.caKey) + if err != nil { + return nil, err + } + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: m.caCert.Raw}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}) + + tlsCert, err := tls.X509KeyPair(append(certPEM, caPEM...), keyPEM) + if err != nil { + return nil, err + } + m.cache[domain] = &tlsCert + return &tlsCert, nil +} + +func (m *Manager) ensureCA() { + if fileExists(CAKeyFile) && fileExists(CACertFile) { + keyPEM, _ := os.ReadFile(CAKeyFile) + certPEM, _ := os.ReadFile(CACertFile) + key, _ := parsePrivateKeyPEM(keyPEM) + cert, _ := parseCertPEM(certPEM) + if key != nil && cert != nil { + m.caKey = key + m.caCert = cert + return + } + } + + _ = os.MkdirAll(CADir, 0755) + key, _ := rsa.GenerateKey(rand.Reader, 2048) + now := time.Now().UTC() + ca := &x509.Certificate{ + SerialNumber: randomSerial(), + Subject: pkix.Name{ + CommonName: "mhr-cfw", + Organization: []string{"mhr-cfw"}, + }, + NotBefore: now, + NotAfter: now.AddDate(10, 0, 0), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature | x509.KeyUsageCRLSign, + BasicConstraintsValid: true, + IsCA: true, + MaxPathLen: 0, + } + der, _ := x509.CreateCertificate(rand.Reader, ca, ca, &key.PublicKey, key) + cert, _ := x509.ParseCertificate(der) + + m.caKey = key + m.caCert = cert + + writePEM(CAKeyFile, "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(key)) + writePEM(CACertFile, "CERTIFICATE", der) +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil +} + +func writePEM(path, typ string, der []byte) { + f, _ := os.Create(path) + defer f.Close() + _ = pem.Encode(f, &pem.Block{Type: typ, Bytes: der}) + if os.PathSeparator == '/' { + _ = os.Chmod(path, 0600) + } +} + +func parsePrivateKeyPEM(pemBytes []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, nil + } + if key, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil { + if k, ok := key.(*rsa.PrivateKey); ok { + return k, nil + } + } + return nil, nil +} + +func parseCertPEM(pemBytes []byte) (*x509.Certificate, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, nil + } + return x509.ParseCertificate(block.Bytes) +} + +func randomSerial() *big.Int { + serialLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serial, _ := rand.Int(rand.Reader, serialLimit) + return serial +} diff --git a/internal/proxy/proxy_server.go b/internal/proxy/proxy_server.go new file mode 100644 index 0000000..b97b595 --- /dev/null +++ b/internal/proxy/proxy_server.go @@ -0,0 +1,589 @@ +package proxy + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "net/textproto" + "regexp" + "strconv" + "strings" + "sync" + "time" + + "github.com/denuitt1/mhr-cfw/internal/config" + "github.com/denuitt1/mhr-cfw/internal/constants" + "github.com/denuitt1/mhr-cfw/internal/fronter" + "github.com/denuitt1/mhr-cfw/internal/logging" + "github.com/denuitt1/mhr-cfw/internal/mitm" +) + +var log = logging.Get("Proxy") + +var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`) + +type ResponseCache struct { + mu sync.Mutex + store map[string]cacheEntry + order []string + size int + max int + Hits int + Misses int +} + +type cacheEntry struct { + raw []byte + expires time.Time +} + +func NewResponseCache(maxMB int) *ResponseCache { + return &ResponseCache{store: map[string]cacheEntry{}, order: []string{}, max: maxMB * 1024 * 1024} +} + +func (c *ResponseCache) Get(url string) []byte { + c.mu.Lock() + defer c.mu.Unlock() + entry, ok := c.store[url] + if !ok { + c.Misses++ + return nil + } + if time.Now().After(entry.expires) { + c.size -= len(entry.raw) + delete(c.store, url) + for i, u := range c.order { + if u == url { + c.order = append(c.order[:i], c.order[i+1:]...) + break + } + } + c.Misses++ + return nil + } + c.Hits++ + return entry.raw +} + +func (c *ResponseCache) Put(url string, raw []byte, ttl int) { + if len(raw) == 0 { + return + } + size := len(raw) + if size > c.max/4 { + return + } + c.mu.Lock() + defer c.mu.Unlock() + for c.size+size > c.max && len(c.store) > 0 { + oldURL := c.order[0] + c.size -= len(c.store[oldURL].raw) + delete(c.store, oldURL) + c.order = c.order[1:] + } + if old, ok := c.store[url]; ok { + for i, u := range c.order { + if u == url { + c.order = append(c.order[:i], c.order[i+1:]...) + break + } + } + c.size -= len(old.raw) + } + c.store[url] = cacheEntry{raw: raw, expires: time.Now().Add(time.Duration(ttl) * time.Second)} + c.order = append(c.order, url) + c.size += size +} + +func (c *ResponseCache) ParseTTL(raw []byte, urlStr string) int { + sep := []byte("\r\n\r\n") + idx := bytes.Index(raw, sep) + if idx < 0 { + return 0 + } + head := strings.ToLower(string(raw[:idx])) + if !strings.HasPrefix(string(raw[:20]), "HTTP/1.1 200") { + return 0 + } + if strings.Contains(head, "no-store") || strings.Contains(head, "private") || strings.Contains(head, "set-cookie:") { + return 0 + } + if m := maxAgeRegex.FindStringSubmatch(head); len(m) == 2 { + v, _ := strconv.Atoi(m[1]) + if v > constants.CacheTTLMax { + return constants.CacheTTLMax + } + return v + } + path := strings.ToLower(strings.Split(urlStr, "?")[0]) + for _, ext := range constants.StaticExts { + if strings.HasSuffix(path, ext) { + return constants.CacheTTLStaticLong + } + } + if strings.Contains(head, "image/") || strings.Contains(head, "font/") { + return constants.CacheTTLStaticLong + } + if strings.Contains(head, "text/css") || strings.Contains(head, "javascript") { + return constants.CacheTTLStaticMed + } + if strings.Contains(head, "text/html") || strings.Contains(head, "application/json") { + return 0 + } + return 0 +} + +type Server struct { + host string + port int + socksEnabled bool + socksHost string + socksPort int + + fronter *fronter.DomainFronter + mitm *mitm.Manager + cache *ResponseCache + + directFailUntil map[string]time.Time + mu sync.Mutex + + servers []net.Listener + wg sync.WaitGroup + ctx context.Context +} + +func NewServer(cfg config.Config) (*Server, error) { + host := cfg.GetString("listen_host", "127.0.0.1") + port := cfg.GetInt("listen_port", 8080) + socksEnabled := cfg.GetBool("socks5_enabled", true) + socksHost := cfg.GetString("socks5_host", host) + socksPort := cfg.GetInt("socks5_port", 1080) + if socksEnabled && socksHost == host && socksPort == port { + return nil, fmt.Errorf("listen_port and socks5_port must differ on the same host (both set to %d on %s)", port, host) + } + + return &Server{ + host: host, + port: port, + socksEnabled: socksEnabled, + socksHost: socksHost, + socksPort: socksPort, + fronter: fronter.New(cfg), + mitm: mitm.NewManager(), + cache: NewResponseCache(constants.CacheMaxMB), + directFailUntil: map[string]time.Time{}, + }, nil +} + +func (s *Server) Start(ctx context.Context) error { + s.ctx = ctx + ln, err := net.Listen("tcp", net.JoinHostPort(s.host, strconv.Itoa(s.port))) + if err != nil { + return err + } + s.servers = append(s.servers, ln) + log.Infof("HTTP proxy listening on %s:%d", s.host, s.port) + + if s.socksEnabled { + socksLn, err := net.Listen("tcp", net.JoinHostPort(s.socksHost, strconv.Itoa(s.socksPort))) + if err != nil { + log.Errorf("SOCKS5 listener failed on %s:%d: %v", s.socksHost, s.socksPort, err) + } else { + s.servers = append(s.servers, socksLn) + log.Infof("SOCKS5 proxy listening on %s:%d", s.socksHost, s.socksPort) + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.acceptLoop(socksLn, s.handleSocksConn) + }() + } + } + + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.acceptLoop(ln, s.handleHTTPConn) + }() + + <-ctx.Done() + for _, l := range s.servers { + _ = l.Close() + } + _ = s.fronter.Close() + s.wg.Wait() + log.Infof("Server stopped") + return nil +} + +func (s *Server) acceptLoop(ln net.Listener, handler func(net.Conn)) { + defer ln.Close() + for { + conn, err := ln.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return + } + continue + } + s.wg.Add(1) + go func() { + defer s.wg.Done() + handler(conn) + }() + } +} + +func (s *Server) handleHTTPConn(conn net.Conn) { + defer conn.Close() + conn.SetDeadline(time.Now().Add(30 * time.Second)) + reader := bufio.NewReader(conn) + line, err := reader.ReadString('\n') + if err != nil { + return + } + + headers := []string{line} + for { + ln, err := reader.ReadString('\n') + if err != nil { + return + } + headers = append(headers, ln) + if ln == "\r\n" || ln == "\n" { + break + } + if sumLen(headers) > constants.MaxHeaderBytes { + return + } + } + + parts := strings.Split(strings.TrimSpace(line), " ") + if len(parts) < 2 { + return + } + method := strings.ToUpper(parts[0]) + if method == "CONNECT" { + s.handleConnect(conn, reader, parts[1]) + return + } + s.handlePlainHTTP(conn, reader, headers) +} + +func (s *Server) handleConnect(conn net.Conn, reader *bufio.Reader, target string) { + host, port := splitHostPort(target, 443) + log.Infof("CONNECT -> %s:%d", host, port) + _, _ = conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) + s.handleTunnel(host, port, conn, reader) +} + +func (s *Server) handleTunnel(host string, port int, conn net.Conn, reader *bufio.Reader) { + if port == 443 { + cfg, err := s.mitm.GetServerTLSConfig(host) + if err != nil { + return + } + tlsConn := tls.Server(conn, cfg) + if err := tlsConn.Handshake(); err != nil { + return + } + s.relayHTTPStream(host, port, tlsConn) + return + } + s.relayHTTPStream(host, port, conn) +} + +func (s *Server) relayHTTPStream(host string, port int, conn net.Conn) { + reader := bufio.NewReader(conn) + for { + conn.SetDeadline(time.Now().Add(time.Duration(constants.ClientIdleTimeout) * time.Second)) + line, err := reader.ReadString('\n') + if err != nil { + return + } + if line == "\r\n" || line == "\n" { + continue + } + headers := []string{line} + for { + ln, err := reader.ReadString('\n') + if err != nil { + return + } + headers = append(headers, ln) + if ln == "\r\n" || ln == "\n" { + break + } + if sumLen(headers) > constants.MaxHeaderBytes { + return + } + } + method, path := parseRequestLine(line) + body, err := readBody(reader, headers) + if err != nil { + return + } + headerMap := parseHeaders(headers[1:]) + + urlStr := normalizeURL(host, port, path) + log.Infof("MITM -> %s %s", method, urlStr) + + origin := headerValue(headerMap, "origin") + acrMethod := headerValue(headerMap, "access-control-request-method") + acrHeaders := headerValue(headerMap, "access-control-request-headers") + if strings.ToUpper(method) == "OPTIONS" && acrMethod != "" { + resp := corsPreflight(origin, acrMethod, acrHeaders) + _, _ = conn.Write(resp) + continue + } + + response := s.fronter.Relay(method, urlStr, headerMap, body) + if origin != "" { + response = injectCORSHeaders(response, origin) + } + _, _ = conn.Write(response) + } +} + +func (s *Server) handlePlainHTTP(conn net.Conn, reader *bufio.Reader, headers []string) { + method, path := parseRequestLine(headers[0]) + body, err := readBody(reader, headers) + if err != nil { + return + } + headerMap := parseHeaders(headers[1:]) + + origin := headerValue(headerMap, "origin") + acrMethod := headerValue(headerMap, "access-control-request-method") + acrHeaders := headerValue(headerMap, "access-control-request-headers") + if strings.ToUpper(method) == "OPTIONS" && acrMethod != "" { + resp := corsPreflight(origin, acrMethod, acrHeaders) + _, _ = conn.Write(resp) + return + } + + urlStr := path + response := s.fronter.Relay(method, urlStr, headerMap, body) + if origin != "" { + response = injectCORSHeaders(response, origin) + } + _, _ = conn.Write(response) +} + +func (s *Server) handleSocksConn(conn net.Conn) { + defer conn.Close() + conn.SetDeadline(time.Now().Add(15 * time.Second)) + buf := make([]byte, 2) + if _, err := io.ReadFull(conn, buf); err != nil { + return + } + if buf[0] != 5 { + return + } + methods := make([]byte, int(buf[1])) + if _, err := io.ReadFull(conn, methods); err != nil { + return + } + conn.Write([]byte{0x05, 0x00}) + + request := make([]byte, 4) + if _, err := io.ReadFull(conn, request); err != nil { + return + } + if request[0] != 5 || request[1] != 0x01 { + conn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + + addrType := request[3] + var host string + switch addrType { + case 0x01: + ip := make([]byte, 4) + if _, err := io.ReadFull(conn, ip); err != nil { + return + } + host = net.IP(ip).String() + case 0x03: + ln := make([]byte, 1) + if _, err := io.ReadFull(conn, ln); err != nil { + return + } + name := make([]byte, int(ln[0])) + if _, err := io.ReadFull(conn, name); err != nil { + return + } + host = string(name) + case 0x04: + ip := make([]byte, 16) + if _, err := io.ReadFull(conn, ip); err != nil { + return + } + host = net.IP(ip).String() + default: + conn.Write([]byte{0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + return + } + portBuf := make([]byte, 2) + if _, err := io.ReadFull(conn, portBuf); err != nil { + return + } + port := int(portBuf[0])<<8 | int(portBuf[1]) + + log.Infof("SOCKS5 CONNECT -> %s:%d", host, port) + conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) + + s.handleTunnel(host, port, conn, bufio.NewReader(conn)) +} + +func sumLen(lines []string) int { + count := 0 + for _, l := range lines { + count += len(l) + } + return count +} + +func parseRequestLine(line string) (string, string) { + parts := strings.Split(strings.TrimSpace(line), " ") + if len(parts) < 2 { + return "GET", "/" + } + return parts[0], parts[1] +} + +func parseHeaders(lines []string) map[string]string { + h := map[string]string{} + for _, ln := range lines { + ln = strings.TrimRight(ln, "\r\n") + if ln == "" { + continue + } + parts := strings.SplitN(ln, ":", 2) + if len(parts) != 2 { + continue + } + key := textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(parts[0])) + val := strings.TrimSpace(parts[1]) + h[key] = val + } + return h +} + +func readBody(reader *bufio.Reader, headers []string) ([]byte, error) { + cl := 0 + for _, ln := range headers { + if strings.HasPrefix(strings.ToLower(ln), "content-length:") { + v := strings.TrimSpace(strings.TrimPrefix(ln, "Content-Length:")) + n, err := strconv.Atoi(v) + if err != nil || n < 0 { + return nil, errors.New("invalid Content-Length") + } + cl = n + } + } + if cl > constants.MaxRequestBodyBytes { + return nil, errors.New("request body too large") + } + if cl == 0 { + return nil, nil + } + buf := make([]byte, cl) + _, err := io.ReadFull(reader, buf) + return buf, err +} + +func normalizeURL(host string, port int, path string) string { + if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { + return path + } + scheme := "http" + if port == 443 { + scheme = "https" + } + if port == 80 || port == 443 { + return fmt.Sprintf("%s://%s%s", scheme, host, path) + } + return fmt.Sprintf("%s://%s:%d%s", scheme, host, port, path) +} + +func headerValue(headers map[string]string, name string) string { + for k, v := range headers { + if strings.ToLower(k) == name { + return v + } + } + return "" +} + +func corsPreflight(origin, acrMethod, acrHeaders string) []byte { + allowOrigin := origin + if allowOrigin == "" { + allowOrigin = "*" + } + allowMethods := "GET, POST, PUT, DELETE, PATCH, OPTIONS" + if acrMethod != "" { + allowMethods = acrMethod + ", " + allowMethods + } + allowHeaders := acrHeaders + if allowHeaders == "" { + allowHeaders = "*" + } + resp := "HTTP/1.1 204 No Content\r\n" + + "Access-Control-Allow-Origin: " + allowOrigin + "\r\n" + + "Access-Control-Allow-Methods: " + allowMethods + "\r\n" + + "Access-Control-Allow-Headers: " + allowHeaders + "\r\n" + + "Access-Control-Allow-Credentials: true\r\n" + + "Access-Control-Max-Age: 86400\r\n" + + "Vary: Origin\r\n" + + "Content-Length: 0\r\n\r\n" + return []byte(resp) +} + +func injectCORSHeaders(response []byte, origin string) []byte { + sep := []byte("\r\n\r\n") + idx := bytes.Index(response, sep) + if idx < 0 { + return response + } + head := string(response[:idx]) + body := response[idx+4:] + lines := strings.Split(head, "\r\n") + filtered := []string{} + for _, ln := range lines { + low := strings.ToLower(ln) + if strings.HasPrefix(low, "access-control-") { + continue + } + filtered = append(filtered, ln) + } + allowOrigin := origin + if allowOrigin == "" { + allowOrigin = "*" + } + filtered = append(filtered, + "Access-Control-Allow-Origin: "+allowOrigin, + "Access-Control-Allow-Credentials: true", + "Access-Control-Allow-Methods: GET, POST, PUT, DELETE, PATCH, OPTIONS", + "Access-Control-Allow-Headers: *", + "Access-Control-Expose-Headers: *", + "Vary: Origin", + ) + newHead := strings.Join(filtered, "\r\n") + "\r\n\r\n" + return append([]byte(newHead), body...) +} + +func splitHostPort(target string, defPort int) (string, int) { + if strings.Contains(target, ":") { + parts := strings.Split(target, ":") + if len(parts) >= 2 { + port, _ := strconv.Atoi(parts[len(parts)-1]) + return strings.Join(parts[:len(parts)-1], ":"), port + } + } + return target, defPort +} diff --git a/internal/scanner/scanner.go b/internal/scanner/scanner.go new file mode 100644 index 0000000..9635c58 --- /dev/null +++ b/internal/scanner/scanner.go @@ -0,0 +1,127 @@ +package scanner + +import ( + "crypto/tls" + "fmt" + "net" + "sort" + "strings" + "time" + + "github.com/denuitt1/mhr-cfw/internal/constants" +) + +type ProbeResult struct { + IP string + LatencyMS *int + Error string +} + +func (r ProbeResult) OK() bool { + return r.LatencyMS != nil +} + +func ScanSync(frontDomain string) bool { + results := run(frontDomain) + + okCount := 0 + fmt.Printf("\nIP LATENCY STATUS\n") + fmt.Printf("------------------- ----------- -------------------------\n") + for _, r := range results { + if r.OK() { + fmt.Printf("%-19s %8dms OK\n", r.IP, *r.LatencyMS) + okCount++ + } else { + fmt.Printf("%-19s %-11s %s\n", r.IP, "-", r.Error) + } + } + fmt.Printf("\nResult: %d / %d reachable\n", okCount, len(results)) + + if okCount == 0 { + fmt.Println("No Google IPs reachable from this network.\n") + return false + } + + fastest := []ProbeResult{} + for _, r := range results { + if r.OK() { + fastest = append(fastest, r) + } + if len(fastest) == 3 { + break + } + } + fmt.Println("\nTop 3 fastest IPs:") + for i, r := range fastest { + fmt.Printf(" %d. %s (%dms)\n", i+1, r.IP, *r.LatencyMS) + } + fmt.Printf("\nRecommended: Set \"google_ip\": \"%s\" in config.json\n\n", fastest[0].IP) + return true +} + +func run(frontDomain string) []ProbeResult { + timeout := time.Duration(constants.GoogleScannerTimeout) * time.Second + sem := make(chan struct{}, constants.GoogleScannerConcurrency) + results := make([]ProbeResult, 0, len(constants.CandidateIPs)) + ch := make(chan ProbeResult, len(constants.CandidateIPs)) + + for _, ip := range constants.CandidateIPs { + ip := ip + sem <- struct{}{} + go func() { + defer func() { <-sem }() + ch <- probeIP(ip, frontDomain, timeout) + }() + } + + for i := 0; i < len(constants.CandidateIPs); i++ { + results = append(results, <-ch) + } + + sort.Slice(results, func(i, j int) bool { + ri, rj := results[i], results[j] + if ri.OK() != rj.OK() { + return ri.OK() + } + if !ri.OK() { + return ri.IP < rj.IP + } + return *ri.LatencyMS < *rj.LatencyMS + }) + return results +} + +func probeIP(ip, sni string, timeout time.Duration) ProbeResult { + start := time.Now() + dialer := &net.Dialer{Timeout: timeout} + raw, err := dialer.Dial("tcp", net.JoinHostPort(ip, "443")) + if err != nil { + return ProbeResult{IP: ip, Error: "network error"} + } + defer raw.Close() + + cfg := &tls.Config{ + ServerName: sni, + InsecureSkipVerify: true, + } + conn := tls.Client(raw, cfg) + _ = conn.SetDeadline(time.Now().Add(timeout)) + if err := conn.Handshake(); err != nil { + return ProbeResult{IP: ip, Error: "handshake failed"} + } + + req := fmt.Sprintf("HEAD / HTTP/1.1\r\nHost: %s\r\nConnection: close\r\n\r\n", sni) + if _, err := conn.Write([]byte(req)); err != nil { + return ProbeResult{IP: ip, Error: "write failed"} + } + buf := make([]byte, 256) + n, err := conn.Read(buf) + if err != nil || n == 0 { + return ProbeResult{IP: ip, Error: "empty response"} + } + if !strings.HasPrefix(string(buf[:n]), "HTTP/") { + return ProbeResult{IP: ip, Error: "invalid response"} + } + ms := int(time.Since(start).Milliseconds()) + return ProbeResult{IP: ip, LatencyMS: &ms} +} diff --git a/internal/setup/wizard.go b/internal/setup/wizard.go new file mode 100644 index 0000000..01f33cc --- /dev/null +++ b/internal/setup/wizard.go @@ -0,0 +1,366 @@ +package setup + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "time" +) + +func RunInteractiveWizard(configPath string) error { + cfg := loadBaseConfig() + + reader := bufio.NewReader(os.Stdin) + ui := newWizardUI() + + ui.Space() + ui.Title("mhr-cfw setup") + ui.Subtitle("Guided configuration for the local relay proxy") + ui.Space() + + if _, err := os.Stat(configPath); err == nil { + if !promptYesNo(reader, ui, "config.json already exists. Overwrite?", false) { + ui.Muted("Nothing changed.") + return nil + } + } + + ui.Section("Shared password") + ui.Muted("Must match AUTH_KEY inside apps_script/Code.gs") + cfg["auth_key"] = prompt(reader, ui, "auth_key", randomAuthKey(32)) + + cfg = configureAppsScript(reader, cfg, ui) + cfg = configureNetwork(reader, cfg, ui) + + if err := writeConfig(configPath, cfg, ui); err != nil { + return err + } + + ui.Space() + ui.Ok("wrote " + filepath.Base(configPath)) + ui.Space() + ui.Section("Next step") + ui.Code("mhr-cfw") + ui.Space() + ui.Warn("AUTH_KEY inside apps_script/Code.gs must match the auth_key you entered") + return nil +} + +type wizardUI struct { + color bool +} + +func newWizardUI() *wizardUI { + color := supportsColor() + return &wizardUI{color: color} +} + +func (w *wizardUI) Space() { + fmt.Println() +} + +func (w *wizardUI) Title(text string) { + line := strings.Repeat("─", max(48, len(text)+12)) + if w.color { + fmt.Println(dim(line)) + fmt.Println(bold(cyan(" " + text + " "))) + fmt.Println(dim(line)) + return + } + fmt.Println(line) + fmt.Println(" " + text) + fmt.Println(line) +} + +func (w *wizardUI) Subtitle(text string) { + if w.color { + fmt.Println(dim(text)) + return + } + fmt.Println(text) +} + +func (w *wizardUI) Section(text string) { + if w.color { + fmt.Println(bold(cyan(text))) + return + } + fmt.Println(text) +} + +func (w *wizardUI) Step(n int, text string) { + label := fmt.Sprintf("%d.", n) + if w.color { + fmt.Println(dim(label), text) + return + } + fmt.Println(label, text) +} + +func (w *wizardUI) Code(text string) { + if w.color { + fmt.Println(dim(" $"), bold(text)) + return + } + fmt.Println(" $", text) +} + +func (w *wizardUI) Prompt(question, hint string) { + if hint != "" { + if w.color { + fmt.Printf("%s %s %s: ", cyan("?"), question, dim("["+hint+"]")) + return + } + fmt.Printf("? %s [%s]: ", question, hint) + return + } + if w.color { + fmt.Printf("%s %s: ", cyan("?"), question) + return + } + fmt.Printf("? %s: ", question) +} + +func (w *wizardUI) Ok(text string) { + if w.color { + fmt.Println(green("[OK]"), text) + return + } + fmt.Println("[OK]", text) +} + +func (w *wizardUI) Warn(text string) { + if w.color { + fmt.Println(yellow("!"), text) + return + } + fmt.Println("!", text) +} + +func (w *wizardUI) Error(text string) { + if w.color { + fmt.Println(red("!"), text) + return + } + fmt.Println("!", text) +} + +func (w *wizardUI) Muted(text string) { + if w.color { + fmt.Println(dim(text)) + return + } + fmt.Println(text) +} + +func supportsColor() bool { + if os.Getenv("NO_COLOR") != "" || os.Getenv("DFT_NO_COLOR") == "1" { + return false + } + if !isTTY(os.Stdout) { + return false + } + return true +} + +func isTTY(f *os.File) bool { + info, err := f.Stat() + if err != nil { + return false + } + return (info.Mode() & os.ModeCharDevice) != 0 +} + +func bold(s string) string { return "\x1b[1m" + s + "\x1b[0m" } +func dim(s string) string { return "\x1b[2m" + s + "\x1b[0m" } +func cyan(s string) string { return "\x1b[36m" + s + "\x1b[0m" } +func green(s string) string { return "\x1b[32m" + s + "\x1b[0m" } +func yellow(s string) string { return "\x1b[33m" + s + "\x1b[0m" } +func red(s string) string { return "\x1b[31m" + s + "\x1b[0m" } + +func max(a, b int) int { + if a > b { + return a + } + return b +} + +func loadBaseConfig() map[string]any { + return map[string]any{ + "mode": "apps_script", + "google_ip": "216.239.38.120", + "front_domain": "www.google.com", + "listen_host": "127.0.0.1", + "listen_port": 8085, + "socks5_enabled": true, + "socks5_port": 1080, + "log_level": "INFO", + "verify_ssl": true, + "lan_sharing": false, + "relay_timeout": 25, + "tls_connect_timeout": 15, + "tcp_connect_timeout": 10, + "max_response_body_bytes": 200 * 1024 * 1024, + "chunked_download_min_size": 5 * 1024 * 1024, + "chunked_download_chunk_size": 512 * 1024, + "chunked_download_max_parallel": 8, + "chunked_download_max_chunks": 256, + "hosts": map[string]string{}, + } +} + +func configureAppsScript(r *bufio.Reader, cfg map[string]any, ui *wizardUI) map[string]any { + ui.Section("Google Apps Script setup") + ui.Step(1, "Open https://script.google.com -> New project") + ui.Step(2, "Paste apps_script/Code.gs from this repo into the editor") + ui.Step(3, "Set AUTH_KEY in Code.gs to the password above") + ui.Step(4, "Deploy -> New deployment -> Web app") + ui.Step(5, "Execute as: Me | Who has access: Anyone") + ui.Step(6, "Copy the Deployment ID and paste it here") + ui.Space() + + idsRaw := prompt(r, ui, "Deployment ID(s) - comma-separated for load balancing", "") + ids := []string{} + for _, v := range strings.Split(idsRaw, ",") { + v = strings.TrimSpace(v) + if v != "" { + ids = append(ids, v) + } + } + if len(ids) == 1 { + cfg["script_id"] = ids[0] + delete(cfg, "script_ids") + } else if len(ids) > 1 { + cfg["script_ids"] = ids + delete(cfg, "script_id") + } + return cfg +} + +func configureNetwork(r *bufio.Reader, cfg map[string]any, ui *wizardUI) map[string]any { + ui.Section("Network settings") + ui.Muted("Press enter to accept defaults") + ui.Space() + + lanSharing := promptYesNo(r, ui, "Enable LAN sharing?", boolVal(cfg["lan_sharing"])) + cfg["lan_sharing"] = lanSharing + + defaultHost := strVal(cfg["listen_host"]) + if lanSharing && defaultHost == "127.0.0.1" { + defaultHost = "0.0.0.0" + } + cfg["listen_host"] = prompt(r, ui, "Listen host", defaultHost) + + port := prompt(r, ui, "HTTP proxy port", fmt.Sprintf("%v", cfg["listen_port"])) + cfg["listen_port"] = toInt(port, 8085) + + socks := promptYesNo(r, ui, "Enable SOCKS5 proxy?", boolVal(cfg["socks5_enabled"])) + cfg["socks5_enabled"] = socks + if socks { + sport := prompt(r, ui, "SOCKS5 port", fmt.Sprintf("%v", cfg["socks5_port"])) + cfg["socks5_port"] = toInt(sport, 1080) + } + return cfg +} + +func writeConfig(path string, cfg map[string]any, ui *wizardUI) error { + if _, err := os.Stat(path); err == nil { + backup := strings.TrimSuffix(path, ".json") + ".json.bak" + _ = copyFile(path, backup) + ui.Muted("existing config.json backed up to " + filepath.Base(backup)) + } + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + + enc := json.NewEncoder(f) + enc.SetIndent("", " ") + return enc.Encode(cfg) +} + +func copyFile(src, dst string) error { + input, err := os.ReadFile(src) + if err != nil { + return err + } + return os.WriteFile(dst, input, 0644) +} + +func prompt(r *bufio.Reader, ui *wizardUI, question, def string) string { + for { + if def != "" { + ui.Prompt(question, def) + } else { + ui.Prompt(question, "") + } + raw, _ := r.ReadString('\n') + raw = strings.TrimSpace(raw) + if raw == "" && def != "" { + return def + } + if raw != "" { + return raw + } + ui.Error("value required") + } +} + +func promptYesNo(r *bufio.Reader, ui *wizardUI, question string, def bool) bool { + hint := "Y/n" + if !def { + hint = "y/N" + } + for { + ui.Prompt(question, hint) + raw, _ := r.ReadString('\n') + raw = strings.TrimSpace(strings.ToLower(raw)) + if raw == "" { + return def + } + if raw == "y" || raw == "yes" { + return true + } + if raw == "n" || raw == "no" { + return false + } + } +} + +func randomAuthKey(length int) string { + const alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + out := make([]byte, length) + seed := time.Now().UnixNano() + for i := range out { + seed = (seed*1664525 + 1013904223) & 0x7fffffff + out[i] = alphabet[int(seed)%len(alphabet)] + } + return string(out) +} + +func boolVal(v any) bool { + if b, ok := v.(bool); ok { + return b + } + return false +} + +func strVal(v any) string { + if s, ok := v.(string); ok { + return s + } + return "" +} + +func toInt(s string, def int) int { + i, err := strconv.Atoi(strings.TrimSpace(s)) + if err != nil { + return def + } + return i +} diff --git a/internal/tui/menu.go b/internal/tui/menu.go new file mode 100644 index 0000000..650636f --- /dev/null +++ b/internal/tui/menu.go @@ -0,0 +1,148 @@ +package tui + +import ( + "bufio" + "fmt" + "os" + "strconv" + "strings" +) + +type Option struct { + Key int + Label string + Handler func() error +} + +type Menu struct { + Title string + Options []Option +} + +func (m *Menu) Run() error { + reader := bufio.NewReader(os.Stdin) + for { + clearScreen() + m.render() + fmt.Print("Select an option: ") + line, _ := reader.ReadString('\n') + line = strings.TrimSpace(line) + if line == "" { + continue + } + idx, err := strconv.Atoi(line) + if err != nil { + fmt.Println("Invalid selection.") + continue + } + for _, opt := range m.Options { + if opt.Key == idx { + if opt.Handler == nil { + return nil + } + if err := opt.Handler(); err != nil { + fmt.Println("Error:", err) + } + fmt.Print("\nPress Enter to return to menu...") + _, _ = reader.ReadString('\n') + break + } + } + } +} + +func (m *Menu) render() { + useColor := supportsColor() + width := max(70, len(m.Title)+16) + borderTop := "╔ " + strings.Repeat("═", width) + " ╗" + borderMid := "╠" + strings.Repeat("═", width+2) + "╣" + borderBot := "╚ " + strings.Repeat("═", width) + " ╝" + inner := "║" + strings.Repeat(" ", width+2) + "║" + tag := "Mhr-Cfw-Go V1.0" + link := "https://github.com/ThisIsDara/" + + centerText := func(text string) string { + pad := width + 2 - len(text) + left := pad / 2 + right := pad - left + return strings.Repeat(" ", left) + text + strings.Repeat(" ", right) + } + + fmt.Println() + fmt.Println(borderTop) + fmt.Println(inner) + + if useColor { + fmt.Println(cyan(centerText(tag))) + fmt.Println(faint(centerText(link))) + } else { + fmt.Println(centerText(tag)) + fmt.Println(centerText(link)) + } + + fmt.Println(inner) + fmt.Println(borderMid) + + for _, opt := range m.Options { + label := fmt.Sprintf("%d) %s", opt.Key, opt.Label) + if useColor { + fmt.Println(" " + violet(">") + " " + bold(ice(fmt.Sprintf("%d)", opt.Key))) + " " + label[3:]) + } else { + fmt.Println(" * " + label) + } + } + + if useColor { + fmt.Println(dim(borderBot)) + } else { + fmt.Println(borderBot) + } + fmt.Println() +} + +func supportsColor() bool { + if os.Getenv("NO_COLOR") != "" || os.Getenv("DFT_NO_COLOR") == "1" { + return false + } + if !isTTY(os.Stdout) { + return false + } + return true +} + +func clearScreen() { + if !supportsColor() { + return + } + fmt.Print("\x1b[2J\x1b[H") +} + +func isTTY(f *os.File) bool { + info, err := f.Stat() + if err != nil { + return false + } + return (info.Mode() & os.ModeCharDevice) != 0 +} + +func pad(s string, width int) string { + if len(s) >= width { + return s[:width] + } + return s + strings.Repeat(" ", width-len(s)) +} + +func bold(s string) string { return "\x1b[1m" + s + "\x1b[0m" } +func dim(s string) string { return "\x1b[2m" + s + "\x1b[0m" } +func faint(s string) string { return "\x1b[38;5;250m" + s + "\x1b[0m" } +func teal(s string) string { return "\x1b[1;38;5;45m" + s + "\x1b[0m" } +func ice(s string) string { return "\x1b[1;38;5;81m" + s + "\x1b[0m" } +func violet(s string) string { return "\x1b[38;5;141m" + s + "\x1b[0m" } +func cyan(s string) string { return "\x1b[1;36m" + s + "\x1b[0m" } + +func max(a, b int) int { + if a > b { + return a + } + return b +}