tcproxy/main.go
2024-08-06 22:58:45 +03:30

348 lines
8.1 KiB
Go

package main
import (
	"bufio"
	"bytes"
	"crypto/tls"
	"embed"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"sync"
	"time"

	"github.com/fsnotify/fsnotify"
)
// versionFile embeds the VERSION file from the package directory; it is read
// by readVersion. The //go:embed directive must stay directly above the var.
//
//go:embed VERSION
var versionFile embed.FS
var (
// NAME is the program name printed by the help and version output.
NAME = "tcproxy"
// VERSION is set in init from the embedded VERSION file, or to the
// placeholder "xx.yy.zz" when that file cannot be read.
VERSION string
// whitelist is the set of client IP strings allowed to connect. It is
// replaced wholesale by loadWhitelist and looked up by handleConnection;
// both accesses are guarded by mu.
whitelist = make(map[string]struct{})
// whitelistFile is the path of the whitelist, loaded at startup and
// watched for changes by watchWhitelist.
whitelistFile = "whitelist.txt"
// mu guards whitelist.
mu sync.RWMutex
)
// init resolves VERSION from the embedded VERSION file, substituting the
// placeholder "xx.yy.zz" when the file cannot be read.
func init() {
	if v, err := readVersion(); err == nil {
		VERSION = v
	} else {
		VERSION = "xx.yy.zz"
	}
}
// printHelp writes the usage text for the command-line flags to standard
// output. The program reads no positional arguments (main only consults the
// flags), so the usage line no longer advertises a "[path]" argument.
func printHelp() {
	fmt.Printf(`Usage: %s [options]
Options:
-b, --bind The address to bind the server (default "localhost:8443").
-v, --version Display the version of the server.
-h, --help Display this help message.
`, NAME)
}
// readVersion returns the contents of the embedded VERSION file with leading
// and trailing whitespace removed, or the error from reading it.
func readVersion() (string, error) {
	raw, err := versionFile.ReadFile("VERSION")
	if err != nil {
		return "", err
	}
	trimmed := bytes.TrimSpace(raw)
	return string(trimmed), nil
}
// printVersion writes a single line to standard output: the program name
// and version, then the Go toolchain version, target OS and target
// architecture, separated by single spaces.
func printVersion() {
	fields := []string{NAME, VERSION, runtime.Version(), runtime.GOOS, runtime.GOARCH}
	fmt.Println(strings.Join(fields, " "))
}
// readOnlyConn adapts an io.Reader to the net.Conn interface so crypto/tls
// can parse a ClientHello from it. Reads are delegated to reader. Write
// always fails with io.ErrClosedPipe, so nothing is ever sent back. Close
// and the deadline setters are no-ops, and both addresses are nil.
type readOnlyConn struct {
	reader io.Reader
}

// Compile-time check that readOnlyConn satisfies net.Conn.
var _ net.Conn = readOnlyConn{}

// Read delegates to the wrapped reader.
func (c readOnlyConn) Read(b []byte) (int, error) {
	return c.reader.Read(b)
}

// Write refuses all data.
func (readOnlyConn) Write([]byte) (int, error) {
	return 0, io.ErrClosedPipe
}

// Close does nothing; the underlying connection is owned by the caller.
func (readOnlyConn) Close() error {
	return nil
}

// LocalAddr reports no address.
func (readOnlyConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr reports no address.
func (readOnlyConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline is a no-op.
func (readOnlyConn) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline is a no-op.
func (readOnlyConn) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op.
func (readOnlyConn) SetWriteDeadline(time.Time) error {
	return nil
}
// main parses the command-line flags, loads the IP whitelist and starts
// watching it for changes, then accepts TCP connections on the bind address
// and hands each one to handleConnection on its own goroutine.
// -v/--version and -h/--help print their output and exit. -b overrides
// --bind when it is given a non-default value.
func main() {
	versionFlag := flag.Bool("version", false, "Display the version of the server")
	versionFlagShort := flag.Bool("v", false, "Display the version")
	helpFlag := flag.Bool("help", false, "Display help message")
	helpFlagShort := flag.Bool("h", false, "Display help message")
	bindAddr := flag.String("bind", "localhost:8443", "The address to bind the server")
	bindAddrShort := flag.String("b", "localhost:8443", "The address to bind the server")
	flag.Parse()
	if *bindAddrShort != "localhost:8443" {
		bindAddr = bindAddrShort
	}
	if *versionFlag || *versionFlagShort {
		printVersion()
		os.Exit(0)
	}
	if *helpFlag || *helpFlagShort {
		printHelp()
		os.Exit(0)
	}
	if err := loadWhitelist(); err != nil {
		log.Fatalf("Failed to load whitelist: %v", err)
	}
	go watchWhitelist()
	l, err := net.Listen("tcp", *bindAddr)
	if err != nil {
		log.Fatalf("Failed to bind to address %s: %v", *bindAddr, err)
	}
	printVersion()
	log.Printf("Listening on %s for incoming HTTPS (or HTTP) connections...", *bindAddr)

	// Accept errors such as EMFILE ("too many open files") tend to repeat
	// immediately. Back off exponentially (5ms doubling up to 1s, the same
	// schedule net/http.Server uses) so a failing Accept does not spin a CPU
	// core and flood the log. The delay resets after a successful Accept.
	var retryDelay time.Duration
	for {
		conn, err := l.Accept()
		if err != nil {
			if retryDelay == 0 {
				retryDelay = 5 * time.Millisecond
			} else {
				retryDelay *= 2
			}
			if retryDelay > time.Second {
				retryDelay = time.Second
			}
			log.Printf("Failed to accept connection: %v; retrying in %v", err, retryDelay)
			time.Sleep(retryDelay)
			continue
		}
		retryDelay = 0
		go handleConnection(conn)
	}
}
// peekClientHello parses the TLS ClientHello at the front of reader without
// losing any data. It returns the parsed hello and a replacement reader.
// That reader yields every byte consumed during parsing, followed by the
// rest of reader. On error both results are nil.
func peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) {
	var consumed bytes.Buffer
	hello, err := readClientHello(io.TeeReader(reader, &consumed))
	if err != nil {
		return nil, nil, err
	}
	replay := io.MultiReader(&consumed, reader)
	return hello, replay, nil
}
// readClientHello parses a TLS ClientHello from reader. It runs a server-side
// handshake over a readOnlyConn and captures a copy of the hello in the
// GetConfigForClient callback.
//
// The handshake is then expected to fail: the config carries no certificates
// and readOnlyConn refuses writes. That error is ignored once a hello has
// been captured. If no hello was seen, the handshake error is returned.
func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
	var captured *tls.ClientHelloInfo
	cfg := &tls.Config{
		GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
			// Keep our own copy rather than the pointer crypto/tls passed in.
			snapshot := *info
			captured = &snapshot
			return nil, nil
		},
	}
	handshakeErr := tls.Server(readOnlyConn{reader: reader}, cfg).Handshake()
	if captured != nil {
		return captured, nil
	}
	return nil, handshakeErr
}
// handleHTTPS proxies a TLS connection without terminating it. It parses the
// ClientHello from reader to learn the SNI server name and dials that name
// on port 443. It then relays traffic in both directions, replaying the
// peeked bytes first, until each side has finished sending.
//
// reader must yield the client's bytes from the start of the stream.
// frontendConn is the same client connection and receives the backend's
// replies.
func handleHTTPS(reader io.Reader, frontendConn net.Conn) {
	clientHello, clientReader, err := peekClientHello(reader)
	if err != nil {
		log.Printf("Failed to peek ClientHello: %v", err)
		return
	}
	serverName := clientHello.ServerName
	// A ClientHello may carry no SNI (e.g. when the client connects by IP
	// address). net.JoinHostPort("", "443") would yield ":443", which
	// net.Dial treats as the local system — possibly this proxy itself.
	if serverName == "" {
		log.Print("ClientHello carries no SNI server name; closing connection")
		return
	}
	log.Printf("Forwarding request for domain: %s", serverName)
	// Connect to the backend server as specified by the SNI
	backendConn, err := net.DialTimeout("tcp", net.JoinHostPort(serverName, "443"), 10*time.Second)
	if err != nil {
		log.Printf("Failed to connect to backend %s: %v", serverName, err)
		return
	}
	defer backendConn.Close()

	var wg sync.WaitGroup
	wg.Add(2)
	// Forward traffic from client to backend, then half-close the backend's
	// write side so it sees EOF while its replies can still flow back.
	go func() {
		defer wg.Done()
		io.Copy(backendConn, clientReader)
		if cw, ok := backendConn.(interface{ CloseWrite() error }); ok {
			cw.CloseWrite()
		}
	}()
	// Forward traffic from backend to client, then half-close toward the client.
	go func() {
		defer wg.Done()
		io.Copy(frontendConn, backendConn)
		if cw, ok := frontendConn.(interface{ CloseWrite() error }); ok {
			cw.CloseWrite()
		}
	}()
	wg.Wait()
	log.Printf("Completed forwarding for domain: %s", serverName)
}
// handleHTTP proxies a plaintext HTTP connection. It parses the first
// request from reader only to learn its Host header, then dials that host on
// port 80. The client's bytes are relayed to the backend exactly as received,
// including the bytes consumed while parsing, and the backend's replies are
// copied back to frontendConn. Each direction is half-closed when it ends,
// and the function returns once both are done.
//
// The request is deliberately not re-serialized with (*http.Request).Write.
// That method always emits an "HTTP/1.1" request line, so an HTTP/1.0 client
// may be answered with chunked encoding it cannot parse. It also adds a Go
// User-Agent header when the client sent none.
func handleHTTP(reader io.Reader, frontendConn net.Conn) {
	// Record everything the parser pulls from the client, including any
	// read-ahead buffered by bufio, so it can be replayed verbatim.
	consumed := new(bytes.Buffer)
	req, err := http.ReadRequest(bufio.NewReader(io.TeeReader(reader, consumed)))
	if err != nil {
		log.Printf("Error reading HTTP request: %v", err)
		return
	}
	backendServer := req.Host
	backendAddr, ok := httpBackendAddr(backendServer)
	if !ok {
		log.Println("No Host header found or empty Host header")
		return
	}
	backendConn, err := net.DialTimeout("tcp", backendAddr, 10*time.Second)
	if err != nil {
		log.Printf("Failed to connect to backend %s: %v", backendServer, err)
		return
	}
	defer backendConn.Close()

	// Replay the consumed bytes first, then continue with the rest of the stream.
	clientReader := io.MultiReader(consumed, reader)

	var wg sync.WaitGroup
	wg.Add(2)
	// Forward client traffic to the backend, then half-close its write side.
	go func() {
		defer wg.Done()
		io.Copy(backendConn, clientReader)
		if cw, ok := backendConn.(interface{ CloseWrite() error }); ok {
			cw.CloseWrite()
		}
	}()
	// Forward backend traffic to the client, then half-close toward the client.
	go func() {
		defer wg.Done()
		io.Copy(frontendConn, backendConn)
		if cw, ok := frontendConn.(interface{ CloseWrite() error }); ok {
			cw.CloseWrite()
		}
	}()
	wg.Wait()
	log.Printf("Completed forwarding HTTP traffic to %s", backendServer)
}

// httpBackendAddr converts an HTTP Host header value into the address to
// dial. Any port in the header is discarded and port 80 is used, matching
// the fixed port 443 used for TLS traffic. Bracketed IPv6 literals are
// accepted with or without a port. It reports false when no host name
// remains.
func httpBackendAddr(hostHeader string) (string, bool) {
	host := hostHeader
	if h, _, err := net.SplitHostPort(hostHeader); err == nil {
		// "example.com:8080" or "[::1]:8080": keep only the host part.
		// Passing the whole value to JoinHostPort would produce
		// "[example.com:8080]:80", which can never be dialed.
		host = h
	} else {
		// No port present; unwrap a bare bracketed IPv6 literal like "[::1]".
		host = strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")
	}
	if host == "" {
		return "", false
	}
	return net.JoinHostPort(host, "80"), true
}
// handleConnection serves one accepted client connection. It first rejects
// clients whose IP is not in the whitelist. It then reads the first byte to
// choose a protocol: 0x16, the TLS handshake record type, is proxied by SNI
// via handleHTTPS, and anything else by Host header via handleHTTP. The
// connection is closed when this function returns.
func handleConnection(clientConn net.Conn) {
	defer clientConn.Close()
	tcpAddr, ok := clientConn.RemoteAddr().(*net.TCPAddr)
	if !ok {
		// The deferred Close handles cleanup; no explicit Close needed.
		log.Printf("Failed to obtain TCP address from connection")
		return
	}
	clientIP := tcpAddr.IP.String()
	mu.RLock()
	_, allowed := whitelist[clientIP]
	mu.RUnlock()
	if !allowed {
		log.Printf("Connection from %s denied", clientIP)
		return
	}
	log.Printf("Received connection from %s:%d", tcpAddr.IP, tcpAddr.Port)

	// Read exactly one byte. io.ReadFull also covers a Read that returns
	// zero bytes without an error.
	var first [1]byte
	if _, err := io.ReadFull(clientConn, first[:]); err != nil {
		log.Printf("Failed to read from client connection: %v", err)
		return
	}
	// Put the peeked byte back in front of the rest of the stream.
	clientReader := io.MultiReader(bytes.NewReader(first[:]), clientConn)

	if first[0] == 0x16 {
		handleHTTPS(clientReader, clientConn)
	} else {
		// Handle as HTTP or other non-TLS traffic
		handleHTTP(clientReader, clientConn)
	}
}
// loadWhitelist reads whitelistFile and, only if the whole file was read
// successfully, replaces the in-memory whitelist under mu.
//
// Each line holds at most one client IP address. Text from '#' to the end of
// the line is a comment, and blank lines are ignored. Valid IPs are stored in
// net.IP.String() form, the same form handleConnection looks up. This lets
// "2001:DB8::1" or "::0001" match clients that would otherwise never match.
// Entries that do not parse as IPs are kept verbatim and logged.
//
// It returns any error from opening or scanning the file.
func loadWhitelist() error {
	file, err := os.Open(whitelistFile)
	if err != nil {
		return err
	}
	defer file.Close()

	newWhitelist := make(map[string]struct{})
	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		line := scanner.Text()
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}
		ip := strings.TrimSpace(line)
		if ip == "" {
			continue
		}
		if parsed := net.ParseIP(ip); parsed != nil {
			ip = parsed.String()
		} else {
			log.Printf("Whitelist entry %q is not a valid IP address", ip)
		}
		newWhitelist[ip] = struct{}{}
	}
	if err := scanner.Err(); err != nil {
		return err
	}

	mu.Lock()
	whitelist = newWhitelist
	mu.Unlock()
	log.Printf("Whitelist updated: %v", newWhitelist)
	return nil
}
// watchWhitelist reloads the whitelist whenever whitelistFile is written or
// (re)created. It runs until the watcher's channels are closed, and exits
// the process if the watcher cannot be set up.
//
// It watches the file's parent directory rather than the file itself. Many
// editors and config tools save by writing a new file and renaming it over
// the old one. That removes the watched inode and silently ends a watch
// placed on the file, so later changes would never be picked up. Reload
// failures are logged, and loadWhitelist leaves the previous whitelist in
// place.
func watchWhitelist() {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		log.Fatalf("Failed to create file watcher: %v", err)
	}
	defer watcher.Close()

	target := filepath.Clean(whitelistFile)
	err = watcher.Add(filepath.Dir(target))
	if err != nil {
		log.Fatalf("Failed to watch file: %v", err)
	}
	for {
		select {
		case event, ok := <-watcher.Events:
			if !ok {
				return
			}
			// The directory watch reports events for every entry; ignore other files.
			if filepath.Clean(event.Name) != target {
				continue
			}
			// Create covers a new file renamed or written into place.
			if event.Op&(fsnotify.Write|fsnotify.Create) != 0 {
				log.Println("Whitelist file modified")
				if err := loadWhitelist(); err != nil {
					log.Printf("Failed to reload whitelist: %v", err)
				}
			}
		case err, ok := <-watcher.Errors:
			if !ok {
				return
			}
			log.Printf("File watcher error: %v", err)
		}
	}
}