Skip to content
Snippets Groups Projects
Commit 773f25b7 authored by Jacob Vosmaer (GitLab)'s avatar Jacob Vosmaer (GitLab)
Browse files

Simplify request buffering

parent 24248b3e
No related branches found
No related tags found
1 merge request!62WIP: Add option to buffer proxy requests
Pipeline #
Loading
Loading
@@ -22,15 +22,20 @@ func (_ *emptyReadCloser) Close() error {
}
 
type requestBuffer struct {
dynamicBufferSize uint
dynamicBufferSize int
handler http.Handler
}
 
func New(size uint, h http.Handler) http.Handler {
func New(size int, h http.Handler) http.Handler {
return &requestBuffer{dynamicBufferSize: size, handler: h}
}
 
func (b *requestBuffer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" {
b.handler.ServeHTTP(w, r)
return
}
body, err := b.buffer(r.Body)
if err != nil {
helper.Fail500(w, r, fmt.Errorf("requestBuffer.buffer: %v", err))
Loading
Loading
@@ -47,55 +52,20 @@ func (b *requestBuffer) buffer(body io.ReadCloser) (io.ReadCloser, error) {
return &emptyReadCloser{}, nil
}
 
peekBuffer, err := staticBuffer(body)
if err != nil {
return nil, err
}
if len(peekBuffer) == 0 {
return &emptyReadCloser{}, nil
}
memoryBuffer, done, err := b.dynamicBufferWithPrefix(body, peekBuffer)
buffer := bytes.NewBuffer(make([]byte, b.dynamicBufferSize))
_, err := io.Copy(buffer, io.LimitReader(body, int64(b.dynamicBufferSize)))
if err != nil {
return nil, err
}
if done {
return ioutil.NopCloser(bytes.NewReader(memoryBuffer)), nil
}
return fileBufferWithPrefix(body, memoryBuffer)
}
func staticBuffer(body io.Reader) ([]byte, error) {
peekArray := [1]byte{}
peekBuffer := peekArray[:]
peeked, err := body.Read(peekBuffer)
if err != nil && err != io.EOF {
return nil, err
}
return peekBuffer[:peeked], nil
}
func (b *requestBuffer) dynamicBufferWithPrefix(body io.Reader, prefix []byte) ([]byte, bool, error) {
smallBuffer := make([]byte, b.dynamicBufferSize)
copy(smallBuffer, prefix)
buffered := len(prefix)
 
n, err := io.Copy(
bytes.NewBuffer(smallBuffer[buffered:]),
io.LimitReader(body, int64(len(smallBuffer)-buffered)),
)
if err != nil && err != io.EOF {
return nil, false, err
if buffer.Len() < b.dynamicBufferSize {
return ioutil.NopCloser(buffer), nil
}
buffered += int(n) // assume len(smallBuffer) fits in an int
 
return smallBuffer[:buffered], buffered < len(smallBuffer), nil
return fileBufferWithPrefix(body, buffer)
}
 
func fileBufferWithPrefix(body io.Reader, prefix []byte) (io.ReadCloser, error) {
func fileBufferWithPrefix(body io.Reader, prefix io.Reader) (io.ReadCloser, error) {
tempFile, err := ioutil.TempFile("", "gitlab-workhorse-request-body")
if err != nil {
return nil, err
Loading
Loading
@@ -105,7 +75,7 @@ func fileBufferWithPrefix(body io.Reader, prefix []byte) (io.ReadCloser, error)
return nil, err
}
 
if n, err := io.Copy(tempFile, bytes.NewReader(prefix)); err != nil || n != int64(len(prefix)) {
if _, err := io.Copy(tempFile, prefix); err != nil {
return nil, err
}
 
Loading
Loading
package requestbuffer
import (
"bytes"
"io/ioutil"
"testing"
)
func TestBuffer(t *testing.T) {
cases := []string{
"",
"0",
"01234",
"0123456789",
}
for _, c := range cases {
rb := &requestBuffer{dynamicBufferSize: 5}
result, err := rb.buffer(ioutil.NopCloser(bytes.NewReader([]byte(c))))
if err != nil {
t.Fatalf("case %q: %v", c, err)
}
value, err := ioutil.ReadAll(result)
if err != nil {
panic(err)
}
if string(value) != c {
t.Fatalf("expected %q, received %q", c, value)
}
}
}
Loading
Loading
@@ -31,7 +31,7 @@ type Config struct {
APILimit uint
APIQueueLimit uint
APIQueueTimeout time.Duration
RequestBufferSize uint
RequestBufferSize int
}
 
type Upstream struct {
Loading
Loading
Loading
Loading
@@ -45,7 +45,7 @@ var secretPath = flag.String("secretPath", "./.gitlab_workhorse_secret", "File w
var apiLimit = flag.Uint("apiLimit", 0, "Number of API requests allowed at single time")
var apiQueueLimit = flag.Uint("apiQueueLimit", 0, "Number of API requests allowed to be queued")
var apiQueueTimeout = flag.Duration("apiQueueDuration", queueing.DefaultTimeout, "Maximum queueing duration of requests")
var requestBufferSize = flag.Uint("requestBufferSize", 0, "Buffer size for request body buffers (0 means no buffering)")
var requestBufferSize = flag.Int("requestBufferSize", 0, "Buffer size for request body buffers (0 means no buffering)")
 
func main() {
flag.Usage = func() {
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment