package proxy import ( "bufio" "bytes" "context" "crypto/tls" "errors" "fmt" "io" "net" "net/url" "net/textproto" "regexp" "strconv" "strings" "sync" "time" "github.com/denuitt1/mhr-cfw/internal/config" "github.com/denuitt1/mhr-cfw/internal/constants" "github.com/denuitt1/mhr-cfw/internal/fronter" "github.com/denuitt1/mhr-cfw/internal/logging" "github.com/denuitt1/mhr-cfw/internal/mitm" ) var log = logging.Get("Proxy") var maxAgeRegex = regexp.MustCompile(`max-age=(\d+)`) type ResponseCache struct { mu sync.Mutex store map[string]cacheEntry order []string size int max int Hits int Misses int } type cacheEntry struct { raw []byte expires time.Time } func NewResponseCache(maxMB int) *ResponseCache { return &ResponseCache{store: map[string]cacheEntry{}, order: []string{}, max: maxMB * 1024 * 1024} } func (c *ResponseCache) Get(url string) []byte { c.mu.Lock() defer c.mu.Unlock() entry, ok := c.store[url] if !ok { c.Misses++ return nil } if time.Now().After(entry.expires) { c.size -= len(entry.raw) delete(c.store, url) for i, u := range c.order { if u == url { c.order = append(c.order[:i], c.order[i+1:]...) break } } c.Misses++ return nil } c.Hits++ return entry.raw } func (c *ResponseCache) Put(url string, raw []byte, ttl int) { if len(raw) == 0 { return } size := len(raw) if size > c.max/4 { return } c.mu.Lock() defer c.mu.Unlock() for c.size+size > c.max && len(c.store) > 0 { oldURL := c.order[0] c.size -= len(c.store[oldURL].raw) delete(c.store, oldURL) c.order = c.order[1:] } if old, ok := c.store[url]; ok { for i, u := range c.order { if u == url { c.order = append(c.order[:i], c.order[i+1:]...) break } } c.size -= len(old.raw) } c.store[url] = cacheEntry{raw: raw, expires: time.Now().Add(time.Duration(ttl) * time.Second)} c.order = append(c.order, url) c.size += size } func (c *ResponseCache) ParseTTL(raw []byte, urlStr string) int { sep := []byte("\r\n\r\n") idx := bytes.Index(raw, sep) if idx < 0 { return 0 } head := strings.ToLower(string(raw[:idx])) if !strings.HasPrefix(string(raw[:20]), "HTTP/1.1 200") { return 0 } if strings.Contains(head, "no-store") || strings.Contains(head, "private") || strings.Contains(head, "set-cookie:") { return 0 } if m := maxAgeRegex.FindStringSubmatch(head); len(m) == 2 { v, _ := strconv.Atoi(m[1]) if v > constants.CacheTTLMax { return constants.CacheTTLMax } return v } path := strings.ToLower(strings.Split(urlStr, "?")[0]) for _, ext := range constants.StaticExts { if strings.HasSuffix(path, ext) { return constants.CacheTTLStaticLong } } if strings.Contains(head, "image/") || strings.Contains(head, "font/") { return constants.CacheTTLStaticLong } if strings.Contains(head, "text/css") || strings.Contains(head, "javascript") { return constants.CacheTTLStaticMed } if strings.Contains(head, "text/html") || strings.Contains(head, "application/json") { return 0 } return 0 } type Server struct { host string port int socksEnabled bool socksHost string socksPort int fronter *fronter.DomainFronter mitm *mitm.Manager cache *ResponseCache directFailUntil map[string]time.Time mu sync.Mutex servers []net.Listener conns map[net.Conn]struct{} connMu sync.Mutex wg sync.WaitGroup ctx context.Context } func NewServer(cfg config.Config) (*Server, error) { host := cfg.GetString("listen_host", "127.0.0.1") port := cfg.GetInt("listen_port", 8080) socksEnabled := cfg.GetBool("socks5_enabled", true) socksHost := cfg.GetString("socks5_host", host) socksPort := cfg.GetInt("socks5_port", 1080) if socksEnabled && socksHost == host && socksPort == port { return nil, fmt.Errorf("listen_port and socks5_port must differ on the same host (both set to %d on %s)", port, host) } return &Server{ host: host, port: port, socksEnabled: socksEnabled, socksHost: socksHost, socksPort: socksPort, fronter: fronter.New(cfg), mitm: mitm.NewManager(), cache: NewResponseCache(constants.CacheMaxMB), directFailUntil: map[string]time.Time{}, conns: map[net.Conn]struct{}{}, }, nil } func (s *Server) Start(ctx context.Context) error { s.ctx = ctx ln, err := net.Listen("tcp", net.JoinHostPort(s.host, strconv.Itoa(s.port))) if err != nil { return err } s.servers = append(s.servers, ln) log.Infof("HTTP proxy listening on %s:%d", s.host, s.port) if s.socksEnabled { socksLn, err := net.Listen("tcp", net.JoinHostPort(s.socksHost, strconv.Itoa(s.socksPort))) if err != nil { log.Errorf("SOCKS5 listener failed on %s:%d: %v", s.socksHost, s.socksPort, err) } else { s.servers = append(s.servers, socksLn) log.Infof("SOCKS5 proxy listening on %s:%d", s.socksHost, s.socksPort) s.wg.Add(1) go func() { defer s.wg.Done() s.acceptLoop(socksLn, s.handleSocksConn) }() } } s.wg.Add(1) go func() { defer s.wg.Done() s.acceptLoop(ln, s.handleHTTPConn) }() <-ctx.Done() for _, l := range s.servers { _ = l.Close() } s.closeAllConns() _ = s.fronter.Close() s.wg.Wait() log.Infof("Server stopped") return nil } func (s *Server) acceptLoop(ln net.Listener, handler func(net.Conn)) { defer ln.Close() for { conn, err := ln.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return } continue } s.wg.Add(1) go func() { defer s.wg.Done() s.trackConn(conn) defer s.untrackConn(conn) handler(conn) }() } } func (s *Server) trackConn(conn net.Conn) { s.connMu.Lock() s.conns[conn] = struct{}{} s.connMu.Unlock() } func (s *Server) untrackConn(conn net.Conn) { s.connMu.Lock() delete(s.conns, conn) s.connMu.Unlock() } func (s *Server) closeAllConns() { s.connMu.Lock() conns := make([]net.Conn, 0, len(s.conns)) for conn := range s.conns { conns = append(conns, conn) } s.connMu.Unlock() for _, conn := range conns { _ = conn.Close() } } func (s *Server) handleHTTPConn(conn net.Conn) { defer conn.Close() conn.SetDeadline(time.Now().Add(30 * time.Second)) reader := bufio.NewReader(conn) line, err := reader.ReadString('\n') if err != nil { return } headers := []string{line} for { ln, err := reader.ReadString('\n') if err != nil { return } headers = append(headers, ln) if ln == "\r\n" || ln == "\n" { break } if sumLen(headers) > constants.MaxHeaderBytes { return } } parts := strings.Split(strings.TrimSpace(line), " ") if len(parts) < 2 { return } method := strings.ToUpper(parts[0]) if method == "CONNECT" { s.handleConnect(conn, reader, parts[1]) return } s.handlePlainHTTP(conn, reader, headers) } func (s *Server) handleConnect(conn net.Conn, reader *bufio.Reader, target string) { host, port := splitHostPort(target, 443) log.Infof("CONNECT -> %s:%d", host, port) _, _ = conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) s.handleTunnel(host, port, conn, reader) } func (s *Server) handleTunnel(host string, port int, conn net.Conn, reader *bufio.Reader) { if port == 443 { cfg, err := s.mitm.GetServerTLSConfig(host) if err != nil { return } tlsConn := tls.Server(conn, cfg) if err := tlsConn.Handshake(); err != nil { return } s.relayHTTPStream(host, port, tlsConn) return } s.relayHTTPStream(host, port, conn) } func (s *Server) relayHTTPStream(host string, port int, conn net.Conn) { reader := bufio.NewReader(conn) for { conn.SetDeadline(time.Now().Add(time.Duration(constants.ClientIdleTimeout) * time.Second)) line, err := reader.ReadString('\n') if err != nil { return } if line == "\r\n" || line == "\n" { continue } headers := []string{line} for { ln, err := reader.ReadString('\n') if err != nil { return } headers = append(headers, ln) if ln == "\r\n" || ln == "\n" { break } if sumLen(headers) > constants.MaxHeaderBytes { return } } method, path := parseRequestLine(line) body, err := readBody(reader, headers) if err != nil { return } headerMap := parseHeaders(headers[1:]) urlStr := normalizeURL(host, port, path) log.Infof("MITM -> %s %s", method, urlStr) origin := headerValue(headerMap, "origin") acrMethod := headerValue(headerMap, "access-control-request-method") acrHeaders := headerValue(headerMap, "access-control-request-headers") if strings.ToUpper(method) == "OPTIONS" && acrMethod != "" { resp := corsPreflight(origin, acrMethod, acrHeaders) _, _ = conn.Write(resp) continue } if s.cacheAllowed(method, urlStr, headerMap, body) { if cached := s.cache.Get(urlStr); cached != nil { if origin != "" { cached = injectCORSHeaders(cached, origin) } _, _ = conn.Write(cached) continue } } response := s.fronter.Relay(method, urlStr, headerMap, body) if s.cacheAllowed(method, urlStr, headerMap, body) { ttl := s.cache.ParseTTL(response, urlStr) if ttl > 0 { s.cache.Put(urlStr, response, ttl) } } if origin != "" { response = injectCORSHeaders(response, origin) } _, _ = conn.Write(response) } } func (s *Server) handlePlainHTTP(conn net.Conn, reader *bufio.Reader, headers []string) { method, path := parseRequestLine(headers[0]) body, err := readBody(reader, headers) if err != nil { return } headerMap := parseHeaders(headers[1:]) origin := headerValue(headerMap, "origin") acrMethod := headerValue(headerMap, "access-control-request-method") acrHeaders := headerValue(headerMap, "access-control-request-headers") if strings.ToUpper(method) == "OPTIONS" && acrMethod != "" { resp := corsPreflight(origin, acrMethod, acrHeaders) _, _ = conn.Write(resp) return } urlStr := path if s.cacheAllowed(method, urlStr, headerMap, body) { if cached := s.cache.Get(urlStr); cached != nil { if origin != "" { cached = injectCORSHeaders(cached, origin) } _, _ = conn.Write(cached) return } } response := s.fronter.Relay(method, urlStr, headerMap, body) if s.cacheAllowed(method, urlStr, headerMap, body) { ttl := s.cache.ParseTTL(response, urlStr) if ttl > 0 { s.cache.Put(urlStr, response, ttl) } } if origin != "" { response = injectCORSHeaders(response, origin) } _, _ = conn.Write(response) } func (s *Server) handleSocksConn(conn net.Conn) { defer conn.Close() conn.SetDeadline(time.Now().Add(15 * time.Second)) buf := make([]byte, 2) if _, err := io.ReadFull(conn, buf); err != nil { return } if buf[0] != 5 { return } methods := make([]byte, int(buf[1])) if _, err := io.ReadFull(conn, methods); err != nil { return } conn.Write([]byte{0x05, 0x00}) request := make([]byte, 4) if _, err := io.ReadFull(conn, request); err != nil { return } if request[0] != 5 || request[1] != 0x01 { conn.Write([]byte{0x05, 0x07, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) return } addrType := request[3] var host string switch addrType { case 0x01: ip := make([]byte, 4) if _, err := io.ReadFull(conn, ip); err != nil { return } host = net.IP(ip).String() case 0x03: ln := make([]byte, 1) if _, err := io.ReadFull(conn, ln); err != nil { return } name := make([]byte, int(ln[0])) if _, err := io.ReadFull(conn, name); err != nil { return } host = string(name) case 0x04: ip := make([]byte, 16) if _, err := io.ReadFull(conn, ip); err != nil { return } host = net.IP(ip).String() default: conn.Write([]byte{0x05, 0x08, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) return } portBuf := make([]byte, 2) if _, err := io.ReadFull(conn, portBuf); err != nil { return } port := int(portBuf[0])<<8 | int(portBuf[1]) log.Infof("SOCKS5 CONNECT -> %s:%d", host, port) conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0, 0, 0, 0, 0, 0}) s.handleTunnel(host, port, conn, bufio.NewReader(conn)) } func sumLen(lines []string) int { count := 0 for _, l := range lines { count += len(l) } return count } func parseRequestLine(line string) (string, string) { parts := strings.Split(strings.TrimSpace(line), " ") if len(parts) < 2 { return "GET", "/" } return parts[0], parts[1] } func parseHeaders(lines []string) map[string]string { h := map[string]string{} for _, ln := range lines { ln = strings.TrimRight(ln, "\r\n") if ln == "" { continue } parts := strings.SplitN(ln, ":", 2) if len(parts) != 2 { continue } key := textproto.CanonicalMIMEHeaderKey(strings.TrimSpace(parts[0])) val := strings.TrimSpace(parts[1]) h[key] = val } return h } func readBody(reader *bufio.Reader, headers []string) ([]byte, error) { cl := 0 for _, ln := range headers { if strings.HasPrefix(strings.ToLower(ln), "content-length:") { v := strings.TrimSpace(strings.TrimPrefix(ln, "Content-Length:")) n, err := strconv.Atoi(v) if err != nil || n < 0 { return nil, errors.New("invalid Content-Length") } cl = n } } if cl > constants.MaxRequestBodyBytes { return nil, errors.New("request body too large") } if cl == 0 { return nil, nil } buf := make([]byte, cl) _, err := io.ReadFull(reader, buf) return buf, err } func normalizeURL(host string, port int, path string) string { if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") { return path } scheme := "http" if port == 443 { scheme = "https" } if port == 80 || port == 443 { return fmt.Sprintf("%s://%s%s", scheme, host, path) } return fmt.Sprintf("%s://%s:%d%s", scheme, host, port, path) } func headerValue(headers map[string]string, name string) string { for k, v := range headers { if strings.ToLower(k) == name { return v } } return "" } func (s *Server) cacheAllowed(method, urlStr string, headers map[string]string, body []byte) bool { if strings.ToUpper(method) != "GET" || len(body) > 0 { return false } for _, name := range constants.UncacheableHeaderNames { if headerValue(headers, name) != "" { return false } } parsed, err := url.Parse(urlStr) if err != nil { return false } path := strings.ToLower(parsed.Path) for _, ext := range constants.StaticExts { if strings.HasSuffix(path, ext) { return true } } return false } func corsPreflight(origin, acrMethod, acrHeaders string) []byte { allowOrigin := origin if allowOrigin == "" { allowOrigin = "*" } allowMethods := "GET, POST, PUT, DELETE, PATCH, OPTIONS" if acrMethod != "" { allowMethods = acrMethod + ", " + allowMethods } allowHeaders := acrHeaders if allowHeaders == "" { allowHeaders = "*" } resp := "HTTP/1.1 204 No Content\r\n" + "Access-Control-Allow-Origin: " + allowOrigin + "\r\n" + "Access-Control-Allow-Methods: " + allowMethods + "\r\n" + "Access-Control-Allow-Headers: " + allowHeaders + "\r\n" + "Access-Control-Allow-Credentials: true\r\n" + "Access-Control-Max-Age: 86400\r\n" + "Vary: Origin\r\n" + "Content-Length: 0\r\n\r\n" return []byte(resp) } func injectCORSHeaders(response []byte, origin string) []byte { sep := []byte("\r\n\r\n") idx := bytes.Index(response, sep) if idx < 0 { return response } head := string(response[:idx]) body := response[idx+4:] lines := strings.Split(head, "\r\n") filtered := []string{} for _, ln := range lines { low := strings.ToLower(ln) if strings.HasPrefix(low, "access-control-") { continue } filtered = append(filtered, ln) } allowOrigin := origin if allowOrigin == "" { allowOrigin = "*" } filtered = append(filtered, "Access-Control-Allow-Origin: "+allowOrigin, "Access-Control-Allow-Credentials: true", "Access-Control-Allow-Methods: GET, POST, PUT, DELETE, PATCH, OPTIONS", "Access-Control-Allow-Headers: *", "Access-Control-Expose-Headers: *", "Vary: Origin", ) newHead := strings.Join(filtered, "\r\n") + "\r\n\r\n" return append([]byte(newHead), body...) } func splitHostPort(target string, defPort int) (string, int) { if strings.Contains(target, ":") { parts := strings.Split(target, ":") if len(parts) >= 2 { port, _ := strconv.Atoi(parts[len(parts)-1]) return strings.Join(parts[:len(parts)-1], ":"), port } } return target, defPort }