Skip to content
Snippets Groups Projects
Commit baceff89 authored by Alessio Caiazza's avatar Alessio Caiazza
Browse files

Merge branch 'jv-workhorse-http-server' into 'master'

Workhorse: simplify main http server

See merge request gitlab-org/gitlab!86767
parents f4288035 1916257b
No related branches found
No related tags found
No related merge requests found
package server
import (
"context"
"crypto/tls"
"crypto/x509"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
)
const (
certFile = "testdata/localhost.crt"
keyFile = "testdata/localhost.key"
)
func TestRun(t *testing.T) {
srv := defaultServer()
require.NoError(t, srv.Run())
defer srv.Close()
require.Len(t, srv.servers, 2)
clients := buildClients(t, srv.servers)
for url, client := range clients {
resp, err := client.Get(url)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
}
}
func TestShutdown(t *testing.T) {
ready := make(chan bool)
done := make(chan bool)
statusCodes := make(chan int)
srv := defaultServer()
srv.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ready <- true
<-done
rw.WriteHeader(200)
})
require.NoError(t, srv.Run())
defer srv.Close()
clients := buildClients(t, srv.servers)
for url, client := range clients {
go func(url string, client *http.Client) {
resp, err := client.Get(url)
require.NoError(t, err)
statusCodes <- resp.StatusCode
}(url, client)
}
for range clients {
<-ready
} // initiate requests
shutdownError := make(chan error)
go func() {
shutdownError <- srv.Shutdown(context.Background())
}()
for url, client := range clients {
require.Eventually(t, func() bool {
_, err := client.Get(url)
return err != nil
}, time.Second, 10*time.Millisecond, "server must stop accepting new requests")
}
for range clients {
done <- true
} // finish requests
require.NoError(t, <-shutdownError)
require.ElementsMatch(t, []int{200, 200}, []int{<-statusCodes, <-statusCodes})
}
func TestShutdown_withTimeout(t *testing.T) {
ready := make(chan bool)
done := make(chan bool)
srv := defaultServer()
srv.Handler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
ready <- true
<-done
rw.WriteHeader(200)
})
require.NoError(t, srv.Run())
defer srv.Close()
clients := buildClients(t, srv.servers)
for url, client := range clients {
go func(url string, client *http.Client) {
client.Get(url)
}(url, client)
}
for range clients {
<-ready
} // initiate requets
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
err := srv.Shutdown(ctx)
require.Error(t, err)
require.EqualError(t, err, "context deadline exceeded")
}
func defaultServer() Server {
return Server{
ListenerConfigs: []config.ListenerConfig{
{
Addr: "127.0.0.1:0",
Network: "tcp",
},
{
Addr: "127.0.0.1:0",
Network: "tcp",
Tls: &config.TlsConfig{
Certificate: certFile,
Key: keyFile,
},
},
},
Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(200)
}),
Errors: make(chan error),
}
}
func buildClients(t *testing.T, servers []*http.Server) map[string]*http.Client {
httpsClient := &http.Client{}
certpool := x509.NewCertPool()
tlsCertificate, err := tls.LoadX509KeyPair(certFile, keyFile)
require.NoError(t, err)
certificate, err := x509.ParseCertificate(tlsCertificate.Certificate[0])
require.NoError(t, err)
certpool.AddCert(certificate)
httpsClient.Transport = &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: certpool,
},
}
httpServer, httpsServer := servers[0], servers[1]
return map[string]*http.Client{
"http://" + httpServer.Addr: http.DefaultClient,
"https://" + httpsServer.Addr: httpsClient,
}
}
package server
package main
 
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"syscall"
 
"gitlab.com/gitlab-org/labkit/log"
 
Loading
Loading
@@ -21,69 +17,7 @@ var tlsVersions = map[string]uint16{
"tls1.3": tls.VersionTLS13,
}
 
type Server struct {
Handler http.Handler
Umask int
ListenerConfigs []config.ListenerConfig
Errors chan error
servers []*http.Server
}
func (s *Server) Run() error {
oldUmask := syscall.Umask(s.Umask)
defer syscall.Umask(oldUmask)
for _, cfg := range s.ListenerConfigs {
listener, err := s.newListener("upstream", cfg)
if err != nil {
return fmt.Errorf("server.Run: failed creating a listener: %v", err)
}
s.runUpstreamServer(listener)
}
return nil
}
func (s *Server) Close() error {
return s.allServers(func(srv *http.Server) error { return srv.Close() })
}
func (s *Server) Shutdown(ctx context.Context) error {
return s.allServers(func(srv *http.Server) error { return srv.Shutdown(ctx) })
}
func (s *Server) allServers(callback func(*http.Server) error) error {
var resultErr error
errC := make(chan error, len(s.servers))
for _, server := range s.servers {
server := server // Capture loop variable
go func() { errC <- callback(server) }()
}
for range s.servers {
if err := <-errC; err != nil {
resultErr = err
}
}
return resultErr
}
func (s *Server) runUpstreamServer(listener net.Listener) {
srv := &http.Server{
Addr: listener.Addr().String(),
Handler: s.Handler,
}
go func() {
s.Errors <- srv.Serve(listener)
}()
s.servers = append(s.servers, srv)
}
func (s *Server) newListener(name string, cfg config.ListenerConfig) (net.Listener, error) {
func newListener(name string, cfg config.ListenerConfig) (net.Listener, error) {
if cfg.Tls == nil {
log.WithFields(log.Fields{"address": cfg.Addr, "network": cfg.Network}).Infof("Running %v server", name)
 
Loading
Loading
package main
import (
"crypto/tls"
"crypto/x509"
"io"
"net"
"os"
"testing"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/config"
)
func TestNewListener(t *testing.T) {
const unixSocket = "testdata/sock"
require.NoError(t, os.RemoveAll(unixSocket))
testCases := []struct {
network, addr string
}{
{"tcp", "127.0.0.1:0"},
{"unix", unixSocket},
}
for _, tc := range testCases {
t.Run(tc.network+"+"+tc.addr, func(t *testing.T) {
l, err := newListener("test", config.ListenerConfig{
Addr: tc.addr,
Network: tc.network,
})
require.NoError(t, err)
defer l.Close()
go pingServer(l)
c, err := net.Dial(tc.network, l.Addr().String())
require.NoError(t, err)
defer c.Close()
pingClient(t, c)
})
}
}
func pingServer(l net.Listener) {
c, err := l.Accept()
if err != nil {
return
}
io.WriteString(c, "ping")
c.Close()
}
func pingClient(t *testing.T, c net.Conn) {
t.Helper()
buf, err := io.ReadAll(c)
require.NoError(t, err)
require.Equal(t, "ping", string(buf))
}
func TestNewListener_TLS(t *testing.T) {
const (
certFile = "testdata/localhost.crt"
keyFile = "testdata/localhost.key"
)
cfg := config.ListenerConfig{Addr: "127.0.0.1:0",
Network: "tcp",
Tls: &config.TlsConfig{
Certificate: certFile,
Key: keyFile,
},
}
l, err := newListener("test", cfg)
require.NoError(t, err)
defer l.Close()
go pingServer(l)
tlsCertificate, err := tls.LoadX509KeyPair(certFile, keyFile)
require.NoError(t, err)
certificate, err := x509.ParseCertificate(tlsCertificate.Certificate[0])
require.NoError(t, err)
certpool := x509.NewCertPool()
certpool.AddCert(certificate)
c, err := tls.Dial("tcp", l.Addr().String(), &tls.Config{RootCAs: certpool})
require.NoError(t, err)
defer c.Close()
pingClient(t, c)
}
Loading
Loading
@@ -23,7 +23,6 @@ import (
"gitlab.com/gitlab-org/gitlab/workhorse/internal/queueing"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/redis"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/secret"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/server"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/upstream"
)
 
Loading
Loading
@@ -241,14 +240,20 @@ func run(boot bootConfig, cfg config.Config) error {
Network: boot.listenNetwork,
Addr: boot.listenAddr,
}
srv := &server.Server{
Handler: up,
Umask: boot.listenUmask,
ListenerConfigs: append(cfg.Listeners, listenerFromBootConfig),
Errors: finalErrors,
var listeners []net.Listener
oldUmask := syscall.Umask(boot.listenUmask)
for _, cfg := range append(cfg.Listeners, listenerFromBootConfig) {
l, err := newListener("upstream", cfg)
if err != nil {
return err
}
listeners = append(listeners, l)
}
if err := srv.Run(); err != nil {
return fmt.Errorf("running server: %v", err)
syscall.Umask(oldUmask)
srv := &http.Server{Handler: up}
for _, l := range listeners {
go func(l net.Listener) { finalErrors <- srv.Serve(l) }(l)
}
 
select {
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment