orchard/internal/tests/exec_ssh_server_test.go

199 lines
3.9 KiB
Go

package tests_test
import (
"crypto/ed25519"
"io"
"net"
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)
// execSSHServer is an in-process SSH server used by tests. It can be told to
// reject a number of incoming TCP connections or session channels, and it
// counts completed handshakes and keepalive requests.
type execSSHServer struct {
	// Listening socket and the configuration used for every SSH handshake.
	listener net.Listener
	config   *ssh.ServerConfig

	// wg tracks the accept loop plus the per-connection and per-session
	// goroutines started from it.
	wg sync.WaitGroup

	// mu guards conns, the set of connections currently being served.
	mu    sync.Mutex
	conns map[net.Conn]struct{}

	// Rejection budgets: every accepted connection / "session" channel
	// decrements the matching counter and is rejected while the decremented
	// value is still non-negative.
	rejectFirstConnections atomic.Int32
	rejectFirstSessions    atomic.Int32

	// Counters exposed to tests through accessor methods.
	successfulConnections atomic.Int32 // completed SSH handshakes
	keepaliveRequests     atomic.Int32 // "keepalive@openssh.com" global requests
}
// startExecSSHServer starts an SSH server listening on an OS-assigned
// loopback port and returns it with its accept loop already running.
//
// The server accepts password authentication for user "admin" with password
// "admin" only, and presents a freshly generated Ed25519 host key. The first
// rejectFirstConnections accepted TCP connections are closed immediately.
// Shutdown is registered with t.Cleanup: the listener is closed, connections
// still being served are closed, and all tracked goroutines are awaited.
func startExecSSHServer(t *testing.T, rejectFirstConnections int32) *execSSHServer {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	// A nil reader makes GenerateKey fall back to crypto/rand.
	_, privateKey, err := ed25519.GenerateKey(nil)
	require.NoError(t, err)
	signer, err := ssh.NewSignerFromKey(privateKey)
	require.NoError(t, err)

	server := &execSSHServer{
		listener: listener,
		conns:    map[net.Conn]struct{}{},
		config: &ssh.ServerConfig{
			PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
				if conn.User() != "admin" || string(password) != "admin" {
					return nil, ssh.ErrNoAuth
				}
				return &ssh.Permissions{}, nil
			},
		},
	}
	server.rejectFirstConnections.Store(rejectFirstConnections)
	server.config.AddHostKey(signer)

	server.wg.Add(1)
	go server.run()

	t.Cleanup(func() {
		// Report a failed Close without aborting the cleanup (as a
		// require.* call would), so the remaining shutdown steps still run.
		if err := server.listener.Close(); err != nil {
			t.Errorf("closing exec SSH server listener: %v", err)
		}
		// Connections a client left open would keep their serve goroutines
		// alive and wg.Wait blocked; close the ones registered so far.
		server.CloseClientConnections()
		server.wg.Wait()
	})

	return server
}
// Addr returns the listener's address in string form, suitable for dialing.
func (server *execSSHServer) Addr() string {
	bound := server.listener.Addr()
	return bound.String()
}
// SuccessfulConnections returns how many connections have completed the SSH
// handshake so far.
func (server *execSSHServer) SuccessfulConnections() int32 {
	handshakes := server.successfulConnections.Load()
	return handshakes
}
// KeepaliveRequests returns how many "keepalive@openssh.com" global requests
// the server has received so far.
func (server *execSSHServer) KeepaliveRequests() int32 {
	seen := server.keepaliveRequests.Load()
	return seen
}
// RejectNextSessions overwrites the session rejection budget: the next count
// "session" channel requests are rejected before sessions are accepted again.
func (server *execSSHServer) RejectNextSessions(count int32) {
	budget := &server.rejectFirstSessions
	budget.Store(count)
}
// CloseClientConnections closes every connection that is currently registered
// as being served. Close errors are ignored.
func (server *execSSHServer) CloseClientConnections() {
	// Copy the set while holding the lock, then close outside of it so the
	// serve goroutines can unregister themselves without contending.
	tracked := func() []net.Conn {
		server.mu.Lock()
		defer server.mu.Unlock()

		snapshot := make([]net.Conn, 0, len(server.conns))
		for c := range server.conns {
			snapshot = append(snapshot, c)
		}
		return snapshot
	}()

	for i := range tracked {
		_ = tracked[i].Close()
	}
}
// run is the accept loop. It returns on the first Accept error, closes
// connections while the connection rejection budget lasts, and serves every
// other connection on its own goroutine tracked by wg.
func (server *execSSHServer) run() {
	defer server.wg.Done()

	for {
		accepted, acceptErr := server.listener.Accept()
		if acceptErr != nil {
			return
		}

		// The counter is decremented on every accept; a non-negative result
		// means this connection still falls inside the rejection budget.
		remaining := server.rejectFirstConnections.Add(-1)
		if remaining >= 0 {
			_ = accepted.Close()
			continue
		}

		handle := func() {
			defer server.wg.Done()
			server.serve(accepted)
		}
		server.wg.Add(1)
		go handle()
	}
}
// serve handles one accepted connection. It registers conn so that
// CloseClientConnections can reach it, performs the SSH handshake, and then
// accepts or rejects incoming channels until the channel stream ends. The
// connection is closed and unregistered on return.
func (server *execSSHServer) serve(conn net.Conn) {
	server.mu.Lock()
	server.conns[conn] = struct{}{}
	server.mu.Unlock()
	defer func() {
		server.mu.Lock()
		delete(server.conns, conn)
		server.mu.Unlock()
	}()
	defer conn.Close()

	serverConn, newChannels, globalRequests, err := ssh.NewServerConn(conn, server.config)
	if err != nil {
		// Failed handshakes are not counted as successful connections.
		return
	}
	defer serverConn.Close()
	server.successfulConnections.Add(1)

	// Track the global-request handler in wg so waiting on wg also covers
	// it. It exits once globalRequests is closed, which the ssh package is
	// expected to do when the connection shuts down.
	server.wg.Add(1)
	go func() {
		defer server.wg.Done()
		server.serveGlobalRequests(globalRequests)
	}()

	for newChannel := range newChannels {
		if newChannel.ChannelType() != "session" {
			_ = newChannel.Reject(ssh.UnknownChannelType, "unsupported channel type")
			continue
		}
		// Every session request decrements the budget; a non-negative result
		// means this one must still be rejected.
		if server.rejectFirstSessions.Add(-1) >= 0 {
			_ = newChannel.Reject(ssh.Prohibited, "session rejected for test")
			continue
		}

		channel, sessionRequests, acceptErr := newChannel.Accept()
		if acceptErr != nil {
			continue
		}

		server.wg.Add(1)
		go func() {
			defer server.wg.Done()
			serveExecSSHSession(channel, sessionRequests)
		}()
	}
}
// serveGlobalRequests drains connection-level requests until the channel is
// closed. It counts "keepalive@openssh.com" requests and answers every request
// that wants a reply with a failure response.
func (server *execSSHServer) serveGlobalRequests(requests <-chan *ssh.Request) {
	const keepaliveType = "keepalive@openssh.com"

	for req := range requests {
		if req.Type == keepaliveType {
			server.keepaliveRequests.Add(1)
		}
		if !req.WantReply {
			continue
		}
		_ = req.Reply(false, nil)
	}
}
// serveExecSSHSession handles requests on a single session channel. Requests
// other than "exec" are refused. The first "exec" request is acknowledged,
// "ok" is written to the channel, an "exit-status" of 0 is sent, and the
// function returns. The channel is closed on return.
func serveExecSSHSession(channel ssh.Channel, requests <-chan *ssh.Request) {
	defer channel.Close()

	// Payload layout of the "exit-status" channel request.
	type exitStatusMsg struct {
		Status uint32
	}

	for req := range requests {
		if req.Type != "exec" {
			_ = req.Reply(false, nil)
			continue
		}

		_ = req.Reply(true, nil)
		_, _ = io.WriteString(channel, "ok")
		payload := ssh.Marshal(exitStatusMsg{Status: 0})
		_, _ = channel.SendRequest("exit-status", false, payload)
		return
	}
}