// ============================================================================== // MasterHttpRelayVPN // Author: MasterkinG32 // Github: https://github.com/masterking32 // Year: 2026 // ============================================================================== package client import ( "context" "encoding/binary" "errors" "fmt" "io" "net" "slices" "strconv" "time" ) const ( socksVersion5 = 0x05 socksMethodNoAuth = 0x00 socksMethodUserPass = 0x02 socksMethodNoAcceptable = 0xFF socksCmdConnect = 0x01 socksAtypIPv4 = 0x01 socksAtypDomain = 0x03 socksAtypIPv6 = 0x04 socksReplySuccess = 0x00 socksReplyGeneralFailure = 0x01 socksReplyCommandUnsupported = 0x07 socksReplyAddressUnsupported = 0x08 socksUserPassVersion = 0x01 socksAuthSuccess = 0x00 socksAuthFailure = 0x01 ) func (c *Client) handleConn(ctx context.Context, conn net.Conn) { c.registerConn(conn) defer c.unregisterConn(conn) defer conn.Close() session := c.sessions.New(conn.RemoteAddr().String()) defer c.sessions.Delete(session.ID) c.log.Infof("accepted client %s session=%d", conn.RemoteAddr(), session.ID) if err := c.handleSOCKS5(ctx, conn, session); err != nil { c.log.Errorf("session=%d closed: %v", session.ID, err) return } } func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn, session *Session) error { version := make([]byte, 1) if _, err := io.ReadFull(conn, version); err != nil { return err } if version[0] != socksVersion5 { return fmt.Errorf("unsupported SOCKS version: %d", version[0]) } method, err := c.negotiateAuth(conn, session) if err != nil { return err } if method == socksMethodUserPass { if err := c.handleUserPassAuth(conn, session); err != nil { return err } } targetHost, targetPort, atyp, err := readConnectRequest(conn) if err != nil { return err } session.TargetHost = targetHost session.TargetPort = targetPort session.AddressType = atyp session.ConnectAccepted = true session.HandshakeDone = true session.LastActivityAt = time.Now() if err := writeSocksReply(conn, socksReplySuccess); err != nil { return err } c.log.Infof( "session=%d CONNECT target=%s:%d auth_method=%d", session.ID, session.TargetHost, session.TargetPort, session.AuthMethod, ) return c.captureInitialPayload(ctx, conn, session) } func (c *Client) negotiateAuth(conn net.Conn, session *Session) (byte, error) { countBuf := make([]byte, 1) if _, err := io.ReadFull(conn, countBuf); err != nil { return 0, err } methodCount := int(countBuf[0]) methods := make([]byte, methodCount) if _, err := io.ReadFull(conn, methods); err != nil { return 0, err } selected := byte(socksMethodNoAcceptable) if c.cfg.SOCKSAuth { if slices.Contains(methods, socksMethodUserPass) { selected = socksMethodUserPass } } else { if slices.Contains(methods, socksMethodNoAuth) { selected = socksMethodNoAuth } } if _, err := conn.Write([]byte{socksVersion5, selected}); err != nil { return 0, err } if selected == socksMethodNoAcceptable { return 0, errors.New("no acceptable auth method") } session.AuthMethod = selected return selected, nil } func (c *Client) handleUserPassAuth(conn net.Conn, session *Session) error { header := make([]byte, 2) if _, err := io.ReadFull(conn, header); err != nil { return err } if header[0] != socksUserPassVersion { return fmt.Errorf("invalid username/password auth version: %d", header[0]) } username := make([]byte, int(header[1])) if _, err := io.ReadFull(conn, username); err != nil { return err } passLen := make([]byte, 1) if _, err := io.ReadFull(conn, passLen); err != nil { return err } password := make([]byte, int(passLen[0])) if _, err := io.ReadFull(conn, password); err != nil { return err } ok := string(username) == c.cfg.SOCKSUsername && string(password) == c.cfg.SOCKSPassword session.UsernameUsed = string(username) if ok { _, err := conn.Write([]byte{socksUserPassVersion, socksAuthSuccess}) return err } _, _ = conn.Write([]byte{socksUserPassVersion, socksAuthFailure}) return errors.New("invalid SOCKS username/password") } func readConnectRequest(conn net.Conn) (string, uint16, byte, error) { header := make([]byte, 4) if _, err := io.ReadFull(conn, header); err != nil { return "", 0, 0, err } if header[0] != socksVersion5 { return "", 0, 0, fmt.Errorf("invalid request version: %d", header[0]) } if header[1] != socksCmdConnect { _ = writeSocksReply(conn, socksReplyCommandUnsupported) return "", 0, 0, fmt.Errorf("unsupported SOCKS command: %d", header[1]) } if header[2] != 0x00 { return "", 0, 0, errors.New("non-zero reserved byte in SOCKS request") } atyp := header[3] host, err := readTargetHost(conn, atyp) if err != nil { _ = writeSocksReply(conn, socksReplyAddressUnsupported) return "", 0, 0, err } portBytes := make([]byte, 2) if _, err := io.ReadFull(conn, portBytes); err != nil { return "", 0, 0, err } return host, binary.BigEndian.Uint16(portBytes), atyp, nil } func readTargetHost(conn net.Conn, atyp byte) (string, error) { switch atyp { case socksAtypIPv4: ip := make([]byte, 4) if _, err := io.ReadFull(conn, ip); err != nil { return "", err } return net.IP(ip).String(), nil case socksAtypIPv6: ip := make([]byte, 16) if _, err := io.ReadFull(conn, ip); err != nil { return "", err } return net.IP(ip).String(), nil case socksAtypDomain: size := make([]byte, 1) if _, err := io.ReadFull(conn, size); err != nil { return "", err } domain := make([]byte, int(size[0])) if _, err := io.ReadFull(conn, domain); err != nil { return "", err } return string(domain), nil default: return "", fmt.Errorf("unsupported address type: %d", atyp) } } func writeSocksReply(conn net.Conn, reply byte) error { resp := []byte{ socksVersion5, reply, 0x00, socksAtypIPv4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, } _, err := conn.Write(resp) return err } func (c *Client) captureInitialPayload(ctx context.Context, conn net.Conn, session *Session) error { peekTimeout := 2 * time.Second idleTimeout := 30 * time.Second buf := make([]byte, 32*1024) if err := conn.SetReadDeadline(time.Now().Add(peekTimeout)); err != nil { return err } n, err := conn.Read(buf) if err == nil && n > 0 { session.InitialPayload = append([]byte(nil), buf[:n]...) session.BytesCaptured += n session.LastActivityAt = time.Now() c.log.Infof( "session=%d captured initial payload bytes=%d target=%s", session.ID, n, net.JoinHostPort(session.TargetHost, strconv.Itoa(int(session.TargetPort))), ) } else if ne, ok := err.(net.Error); !ok || !ne.Timeout() { if errors.Is(err, io.EOF) { return nil } if err != nil { return err } } for { select { case <-ctx.Done(): return nil default: } if err := conn.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil { return err } n, err := conn.Read(buf) if n > 0 { session.BytesCaptured += n session.LastActivityAt = time.Now() c.log.Debugf("session=%d buffered payload chunk=%d total=%d", session.ID, n, session.BytesCaptured) } if err != nil { if errors.Is(err, io.EOF) { return nil } if ne, ok := err.(net.Error); ok && ne.Timeout() { return nil } return err } } }