fork of https://github.com/sourcegraph/zoekt
1package main
2
3import (
4 "context"
5 "io"
6 "net"
7 "net/http"
8
9 "github.com/mxk/go-flowrate/flowrate"
10)
11
// connReadWriter is a net.Conn whose byte stream is redirected through
// Reader and Writer, while every other net.Conn method (Close, the
// deadline setters, LocalAddr, RemoteAddr) is served by the embedded Conn.
// limitDial uses it to interpose rate-limited readers and writers on a
// freshly dialed connection.
type connReadWriter struct {
	net.Conn

	// Reader serves Read calls in place of Conn.Read.
	Reader io.Reader
	// Writer serves Write calls in place of Conn.Write.
	Writer io.Writer
}

// Compile-time check that connReadWriter still satisfies net.Conn.
var _ net.Conn = (*connReadWriter)(nil)

// Read fills p from rw.Reader, bypassing the embedded Conn's Read.
func (rw *connReadWriter) Read(p []byte) (int, error) {
	n, err := rw.Reader.Read(p)
	return n, err
}

// Write hands p to rw.Writer, bypassing the embedded Conn's Write.
func (rw *connReadWriter) Write(p []byte) (int, error) {
	n, err := rw.Writer.Write(p)
	return n, err
}
26
// dial has the same signature as net.Dialer.DialContext and
// http.Transport.DialContext, so either can be wrapped by limitDial.
type dial func(ctx context.Context, network, address string) (net.Conn, error)
28
29func limitDial(d dial, limit int64) dial {
30 if limit <= 0 {
31 return d
32 }
33
34 return func(ctx context.Context, network, addr string) (net.Conn, error) {
35 conn, err := d(ctx, network, addr)
36 if err != nil {
37 return nil, err
38 }
39 return &connReadWriter{
40 Conn: conn,
41 Reader: flowrate.NewReader(conn, limit),
42 Writer: flowrate.NewWriter(conn, limit),
43 }, nil
44 }
45}
46
47func limitHTTPDefaultClient(limitMbps int64) {
48 if limitMbps <= 0 {
49 return
50 }
51
52 const megabit = 1000 * 1000
53 limit := (limitMbps * megabit) / 8
54
55 t := http.DefaultTransport.(*http.Transport)
56 t.DialContext = limitDial(t.DialContext, limit)
57}