160 lines
3.1 KiB
Go
160 lines
3.1 KiB
Go
package tests_test
|
|
|
|
import (
|
|
"crypto/ed25519"
|
|
"io"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
type execSSHServer struct {
|
|
listener net.Listener
|
|
config *ssh.ServerConfig
|
|
|
|
rejectFirstConnections atomic.Int32
|
|
successfulConnections atomic.Int32
|
|
acceptedSessions atomic.Int32
|
|
releaseSessions <-chan struct{}
|
|
done chan struct{}
|
|
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
func startExecSSHServerWithSessionGate(
|
|
t *testing.T,
|
|
rejectFirstConnections int32,
|
|
releaseSessions <-chan struct{},
|
|
) *execSSHServer {
|
|
t.Helper()
|
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
_, privateKey, err := ed25519.GenerateKey(nil)
|
|
require.NoError(t, err)
|
|
|
|
signer, err := ssh.NewSignerFromKey(privateKey)
|
|
require.NoError(t, err)
|
|
|
|
server := &execSSHServer{
|
|
listener: listener,
|
|
releaseSessions: releaseSessions,
|
|
done: make(chan struct{}),
|
|
config: &ssh.ServerConfig{
|
|
PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
|
if conn.User() != "admin" || string(password) != "admin" {
|
|
return nil, ssh.ErrNoAuth
|
|
}
|
|
|
|
return &ssh.Permissions{}, nil
|
|
},
|
|
},
|
|
}
|
|
server.rejectFirstConnections.Store(rejectFirstConnections)
|
|
server.config.AddHostKey(signer)
|
|
|
|
server.wg.Add(1)
|
|
go server.run()
|
|
|
|
t.Cleanup(func() {
|
|
close(server.done)
|
|
require.NoError(t, server.listener.Close())
|
|
server.wg.Wait()
|
|
})
|
|
|
|
return server
|
|
}
|
|
|
|
func (server *execSSHServer) Addr() string {
|
|
return server.listener.Addr().String()
|
|
}
|
|
|
|
func (server *execSSHServer) run() {
|
|
defer server.wg.Done()
|
|
|
|
for {
|
|
conn, err := server.listener.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if server.rejectFirstConnections.Add(-1) >= 0 {
|
|
_ = conn.Close()
|
|
|
|
continue
|
|
}
|
|
|
|
server.wg.Add(1)
|
|
go func() {
|
|
defer server.wg.Done()
|
|
|
|
server.serve(conn)
|
|
}()
|
|
}
|
|
}
|
|
|
|
func (server *execSSHServer) serve(conn net.Conn) {
|
|
defer conn.Close()
|
|
|
|
serverConn, newChannels, requests, err := ssh.NewServerConn(conn, server.config)
|
|
if err != nil {
|
|
return
|
|
}
|
|
server.successfulConnections.Add(1)
|
|
defer serverConn.Close()
|
|
|
|
go ssh.DiscardRequests(requests)
|
|
|
|
for newChannel := range newChannels {
|
|
if newChannel.ChannelType() != "session" {
|
|
_ = newChannel.Reject(ssh.UnknownChannelType, "unsupported channel type")
|
|
|
|
continue
|
|
}
|
|
|
|
channel, requests, err := newChannel.Accept()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
server.acceptedSessions.Add(1)
|
|
|
|
server.wg.Add(1)
|
|
go func() {
|
|
defer server.wg.Done()
|
|
|
|
server.serveExecSSHSession(channel, requests)
|
|
}()
|
|
}
|
|
}
|
|
|
|
func (server *execSSHServer) serveExecSSHSession(channel ssh.Channel, requests <-chan *ssh.Request) {
|
|
defer channel.Close()
|
|
|
|
for request := range requests {
|
|
switch request.Type {
|
|
case "exec":
|
|
_ = request.Reply(true, nil)
|
|
_, _ = io.WriteString(channel, "ok")
|
|
if server.releaseSessions != nil {
|
|
select {
|
|
case <-server.releaseSessions:
|
|
case <-server.done:
|
|
return
|
|
}
|
|
}
|
|
_, _ = channel.SendRequest("exit-status", false, ssh.Marshal(struct {
|
|
Status uint32
|
|
}{Status: 0}))
|
|
|
|
return
|
|
default:
|
|
_ = request.Reply(false, nil)
|
|
}
|
|
}
|
|
}
|