package main import ( "bufio" "bytes" "crypto/tls" "embed" "flag" "fmt" "io" "log" "net" "net/http" "os" "runtime" "strings" "sync" "time" "github.com/fsnotify/fsnotify" ) //go:embed VERSION var versionFile embed.FS var ( NAME = "tcproxy" VERSION string whitelist = make(map[string]struct{}) whitelistFile = "whitelist.txt" mu sync.RWMutex ) func init() { var err error VERSION, err = readVersion() if err != nil { VERSION = "xx.yy.zz" } } func printHelp() { helpText := fmt.Sprintf(`Usage: %s [options] [path] Options: -b, --bind The address to bind the server (default "localhost:8443"). -v, --version Display the version of the server. -h, --help Display this help message. `, NAME) fmt.Print(helpText) } func readVersion() (string, error) { version, err := versionFile.ReadFile("VERSION") if err != nil { return "", err } return strings.TrimSpace(string(version)), nil } func printVersion() { fmt.Printf( "%s %s %s %s %s\n", NAME, VERSION, runtime.Version(), runtime.GOOS, runtime.GOARCH, ) } type readOnlyConn struct { reader io.Reader } func (conn readOnlyConn) Read(p []byte) (int, error) { return conn.reader.Read(p) } func (conn readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe } func (conn readOnlyConn) Close() error { return nil } func (conn readOnlyConn) LocalAddr() net.Addr { return nil } func (conn readOnlyConn) RemoteAddr() net.Addr { return nil } func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } func main() { versionFlag := flag.Bool("version", false, "Display the version of the server") versionFlagShort := flag.Bool("v", false, "Display the version") helpFlag := flag.Bool("help", false, "Display help message") helpFlagShort := flag.Bool("h", false, "Display help message") bindAddr := flag.String("bind", "localhost:8443", "The address to bind the server") bindAddrShort := flag.String("b", "localhost:8443", "The address to bind the server") flag.Parse() if *bindAddrShort != "localhost:8443" { bindAddr = bindAddrShort } if *versionFlag || *versionFlagShort { printVersion() os.Exit(0) } if *helpFlag || *helpFlagShort { printHelp() os.Exit(0) } err := loadWhitelist() if err != nil { log.Fatalf("Failed to load whitelist: %v", err) } go watchWhitelist() l, err := net.Listen("tcp", *bindAddr) if err != nil { log.Fatalf("Failed to bind to address %s: %v", *bindAddr, err) } printVersion() log.Printf("Listening on %s for incoming HTTPS (or HTTP) connections...", *bindAddr) for { conn, err := l.Accept() if err != nil { log.Printf("Failed to accept connection: %v", err) continue } go handleConnection(conn) } } func peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) { peekedBytes := new(bytes.Buffer) hello, err := readClientHello(io.TeeReader(reader, peekedBytes)) if err != nil { return nil, nil, err } return hello, io.MultiReader(peekedBytes, reader), nil } func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) { var hello *tls.ClientHelloInfo err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{ GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { hello = new(tls.ClientHelloInfo) *hello = *argHello return nil, nil }, }).Handshake() if hello == nil { return nil, err } return hello, nil } func handleHTTPS(reader io.Reader, frontendConn net.Conn) { clientHello, clientReader, err := peekClientHello(reader) if err != nil { log.Print("Failed to peek ClientHello:", err) return } serverName := clientHello.ServerName log.Printf("Forwarding request for domain: %s", serverName) // Connect to the backend server as specified by the SNI backendConn, err := net.DialTimeout("tcp", net.JoinHostPort(serverName, "443"), 10*time.Second) if err != nil { log.Printf("Failed to connect to backend %s: %v", serverName, err) return } defer backendConn.Close() var wg sync.WaitGroup wg.Add(2) // Forward traffic from client to backend go func() { io.Copy(backendConn, clientReader) backendConn.(*net.TCPConn).CloseWrite() wg.Done() }() // Forward traffic from backend to client go func() { io.Copy(frontendConn, backendConn) frontendConn.(*net.TCPConn).CloseWrite() wg.Done() }() wg.Wait() log.Printf("Completed forwarding for domain: %s", serverName) } func handleHTTP(reader io.Reader, frontendConn net.Conn) { // Buffer the reader to peek into the HTTP request bufferedReader := bufio.NewReader(reader) req, err := http.ReadRequest(bufferedReader) if err != nil { log.Printf("Error reading HTTP request: %v", err) return } // Extract the Host from the HTTP request backendServer := req.Host if backendServer == "" { log.Println("No Host header found or empty Host header") return } // Establish a connection to the backend server backendConn, err := net.DialTimeout("tcp", net.JoinHostPort(backendServer, "80"), 10*time.Second) if err != nil { log.Printf("Failed to connect to backend %s: %v", backendServer, err) return } defer backendConn.Close() // Write the original request to the backend req.Write(backendConn) // Use a WaitGroup to wait for both copy operations to complete var wg sync.WaitGroup wg.Add(2) // Forward remaining traffic from client to backend go func() { io.Copy(backendConn, bufferedReader) // bufferedReader now points to the rest of the stream after the initial read backendConn.(*net.TCPConn).CloseWrite() wg.Done() }() // Forward traffic from backend to client go func() { io.Copy(frontendConn, backendConn) frontendConn.(*net.TCPConn).CloseWrite() wg.Done() }() // Wait for both forwarding operations to complete wg.Wait() log.Printf("Completed forwarding HTTP traffic to %s", backendServer) } func handleConnection(clientConn net.Conn) { defer clientConn.Close() tcpAddr, ok := clientConn.RemoteAddr().(*net.TCPAddr) if !ok { log.Printf("Failed to obtain TCP address from connection") clientConn.Close() return } clientIP := tcpAddr.IP.String() mu.RLock() _, allowed := whitelist[clientIP] mu.RUnlock() if !allowed { log.Printf("Connection from %s denied", clientIP) return } log.Printf("Received connection from %s:%d", tcpAddr.IP, tcpAddr.Port) // Create a buffer to peek the first byte buf := make([]byte, 1) _, err := clientConn.Read(buf) if err != nil { log.Print("Failed to read from client connection:", err) return } // Use a MultiReader to put the peeked byte back in front of the clientConn stream clientReader := io.MultiReader(bytes.NewReader(buf), clientConn) // Check if the first byte indicates a TLS handshake (0x16) if buf[0] == 0x16 { // Handle as HTTPS handleHTTPS(clientReader, clientConn) } else { // Handle as HTTP or other non-TLS traffic handleHTTP(clientReader, clientConn) } } func loadWhitelist() error { file, err := os.Open(whitelistFile) if err != nil { return err } defer file.Close() newWhitelist := make(map[string]struct{}) scanner := bufio.NewScanner(file) for scanner.Scan() { ip := strings.TrimSpace(scanner.Text()) if ip != "" && !strings.HasPrefix(ip, "#") { newWhitelist[ip] = struct{}{} } } if err := scanner.Err(); err != nil { return err } mu.Lock() whitelist = newWhitelist mu.Unlock() log.Printf("Whitelist updated: %v", newWhitelist) return nil } func watchWhitelist() { watcher, err := fsnotify.NewWatcher() if err != nil { log.Fatalf("Failed to create file watcher: %v", err) } defer watcher.Close() err = watcher.Add(whitelistFile) if err != nil { log.Fatalf("Failed to watch file: %v", err) } for { select { case event, ok := <-watcher.Events: if !ok { return } if event.Op&fsnotify.Write == fsnotify.Write { log.Println("Whitelist file modified") if err := loadWhitelist(); err != nil { log.Printf("Failed to reload whitelist: %v", err) } } case err, ok := <-watcher.Errors: if !ok { return } log.Printf("File watcher error: %v", err) } } }