fork of https://github.com/sourcegraph/zoekt
1// Copyright 2016 Google Inc. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Command zoekt-webserver responds to search queries, using an index generated
16// by another program such as zoekt-indexserver.
17
18package main
19
20import (
21 "context"
22 "crypto/tls"
23 "errors"
24 "flag"
25 "fmt"
26 "html/template"
27 "log"
28 "net"
29 "net/http"
30 "net/http/httputil"
31 "net/url"
32 "os"
33 "os/signal"
34 "path/filepath"
35 "runtime"
36 "strconv"
37 "strings"
38 "time"
39
40 "github.com/sourcegraph/mountinfo"
41 "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
42 "golang.org/x/net/http2"
43 "golang.org/x/net/http2/h2c"
44 "google.golang.org/grpc"
45
46 "github.com/sourcegraph/zoekt"
47 "github.com/sourcegraph/zoekt/build"
48 "github.com/sourcegraph/zoekt/debugserver"
49 zoektgrpc "github.com/sourcegraph/zoekt/grpc"
50 v1 "github.com/sourcegraph/zoekt/grpc/v1"
51 "github.com/sourcegraph/zoekt/internal/profiler"
52 "github.com/sourcegraph/zoekt/internal/tracer"
53 "github.com/sourcegraph/zoekt/query"
54 "github.com/sourcegraph/zoekt/shards"
55 "github.com/sourcegraph/zoekt/stream"
56 "github.com/sourcegraph/zoekt/trace"
57 "github.com/sourcegraph/zoekt/web"
58
59 "github.com/opentracing/opentracing-go"
60 "github.com/prometheus/client_golang/prometheus"
61 "github.com/prometheus/client_golang/prometheus/promauto"
62 "github.com/shirou/gopsutil/v3/disk"
63 sglog "github.com/sourcegraph/log"
64 "github.com/uber/jaeger-client-go"
65 oteltrace "go.opentelemetry.io/otel/trace"
66 "go.uber.org/automaxprocs/maxprocs"
67)
68
// logFormat is the time layout used for the timestamp embedded in log file
// names. It uses '-' between the clock fields instead of ':'.
const logFormat = "2006-01-02T15-04-05.999999999Z07"

// divertLogs points the standard logger at a file inside dir and switches to
// a newly created file every interval. File names contain the creation time
// and the process ID. The function never returns, so it is meant to run in
// its own goroutine; if a log file cannot be created the process exits with
// status 2.
func divertLogs(dir string, interval time.Duration) {
	ticker := time.NewTicker(interval)
	var previous *os.File

	// The loop body runs once immediately and then again after every tick.
	for ; ; <-ticker.C {
		name := fmt.Sprintf("zoekt-webserver.%s.%d.log", time.Now().Format(logFormat), os.Getpid())
		path := filepath.Join(dir, name)
		fmt.Fprintf(os.Stderr, "writing logs to %s\n", path)

		current, err := os.Create(path)
		if err != nil {
			// There is not much we can do now.
			fmt.Fprintf(os.Stderr, "can't create output file %s: %v\n", path, err)
			os.Exit(2)
		}

		// Switch the logger first, then release the file it was using. On the
		// first pass previous is nil; Close on a nil *os.File only returns an
		// error, which is ignored here.
		log.SetOutput(current)
		previous.Close()
		previous = current
	}
}
93
// templateExtension is the file name suffix that identifies a template file.
const templateExtension = ".html.tpl"

// loadTemplates parses every file in dir whose name ends in templateExtension
// and associates it with tpl under the file's base name with the extension
// removed (so "search.html.tpl" becomes the template "search").
//
// It returns an error, wrapping the underlying cause, when the glob pattern
// is malformed, a file cannot be read, or a template fails to parse.
// Templates parsed before the failure remain associated with tpl.
func loadTemplates(tpl *template.Template, dir string) error {
	// filepath.Join keeps the pattern correct regardless of trailing
	// separators in dir and of the platform's path separator.
	paths, err := filepath.Glob(filepath.Join(dir, "*"+templateExtension))
	if err != nil {
		// Report the failure to the caller instead of exiting the process
		// from inside a function that already returns an error.
		return fmt.Errorf("glob templates in %s: %w", dir, err)
	}

	log.Printf("loading templates: %v", paths)
	for _, path := range paths {
		content, err := os.ReadFile(path)
		if err != nil {
			return err
		}

		name := strings.TrimSuffix(filepath.Base(path), templateExtension)
		if _, err := tpl.New(name).Parse(string(content)); err != nil {
			return fmt.Errorf("template.Parse(%s): %w", path, err)
		}
	}
	return nil
}
117
118func writeTemplates(dir string) error {
119 if dir == "" {
120 return fmt.Errorf("must set --template_dir")
121 }
122
123 for k, v := range web.TemplateText {
124 nm := filepath.Join(dir, k+templateExtension)
125 if err := os.WriteFile(nm, []byte(v), 0o644); err != nil {
126 return err
127 }
128 }
129 return nil
130}
131
132func main() {
133 logDir := flag.String("log_dir", "", "log to this directory rather than stderr.")
134 logRefresh := flag.Duration("log_refresh", 24*time.Hour, "if using --log_dir, start writing a new file this often.")
135
136 listen := flag.String("listen", ":6070", "listen on this address.")
137 index := flag.String("index", build.DefaultDir, "set index directory to use")
138 html := flag.Bool("html", true, "enable HTML interface")
139 enableRPC := flag.Bool("rpc", false, "enable go/net RPC")
140 enableIndexserverProxy := flag.Bool("indexserver_proxy", false, "proxy requests with URLs matching the path /indexserver/ to <index>/indexserver.sock")
141 print := flag.Bool("print", false, "enable local result URLs")
142 enablePprof := flag.Bool("pprof", false, "set to enable remote profiling.")
143 sslCert := flag.String("ssl_cert", "", "set path to SSL .pem holding certificate.")
144 sslKey := flag.String("ssl_key", "", "set path to SSL .pem holding key.")
145 hostCustomization := flag.String(
146 "host_customization", "",
147 "specify host customization, as HOST1=QUERY,HOST2=QUERY")
148
149 templateDir := flag.String("template_dir", "", "set directory from which to load custom .html.tpl template files")
150 dumpTemplates := flag.Bool("dump_templates", false, "dump templates into --template_dir and exit.")
151 version := flag.Bool("version", false, "Print version number")
152
153 flag.Parse()
154
155 if *version {
156 fmt.Printf("zoekt-webserver version %q\n", zoekt.Version)
157 os.Exit(0)
158 }
159
160 if *dumpTemplates {
161 if err := writeTemplates(*templateDir); err != nil {
162 log.Fatal(err)
163 }
164 os.Exit(0)
165 }
166
167 liblog := sglog.Init(sglog.Resource{
168 Name: "zoekt-webserver",
169 Version: zoekt.Version,
170 InstanceID: os.Getenv("HOSTNAME"),
171 })
172 defer liblog.Sync()
173 tracer.Init("zoekt-webserver", zoekt.Version)
174 profiler.Init("zoekt-webserver", zoekt.Version, -1)
175
176 if *logDir != "" {
177 if fi, err := os.Lstat(*logDir); err != nil || !fi.IsDir() {
178 log.Fatalf("%s is not a directory", *logDir)
179 }
180 // We could do fdup acrobatics to also redirect
181 // stderr, but it is simpler and more portable for the
182 // caller to divert stderr output if necessary.
183 go divertLogs(*logDir, *logRefresh)
184 }
185
186 // Tune GOMAXPROCS to match Linux container CPU quota.
187 _, _ = maxprocs.Set()
188
189 if err := os.MkdirAll(*index, 0o755); err != nil {
190 log.Fatal(err)
191 }
192
193 mustRegisterDiskMonitor(*index)
194
195 metricsLogger := sglog.Scoped("metricsRegistration", "")
196
197 mustRegisterMemoryMapMetrics(metricsLogger)
198
199 opts := mountinfo.CollectorOpts{Namespace: "zoekt_webserver"}
200 c := mountinfo.NewCollector(metricsLogger, opts, map[string]string{"indexDir": *index})
201
202 prometheus.DefaultRegisterer.MustRegister(c)
203
204 // Do not block on loading shards so we can become partially available
205 // sooner. Otherwise on large instances zoekt can be unavailable on the
206 // order of minutes.
207 searcher, err := shards.NewDirectorySearcherFast(*index)
208 if err != nil {
209 log.Fatal(err)
210 }
211
212 searcher = &loggedSearcher{
213 Streamer: searcher,
214 Logger: sglog.Scoped("searcher", ""),
215 }
216
217 s := &web.Server{
218 Searcher: searcher,
219 Top: web.Top,
220 Version: zoekt.Version,
221 }
222
223 if *templateDir != "" {
224 if err := loadTemplates(s.Top, *templateDir); err != nil {
225 log.Fatalf("loadTemplates: %v", err)
226 }
227 }
228
229 s.Print = *print
230 s.HTML = *html
231 s.RPC = *enableRPC
232
233 if *hostCustomization != "" {
234 s.HostCustomQueries = map[string]string{}
235 for _, h := range strings.SplitN(*hostCustomization, ",", -1) {
236 if len(h) == 0 {
237 continue
238 }
239 fields := strings.SplitN(h, "=", 2)
240 if len(fields) < 2 {
241 log.Fatalf("invalid host_customization %q", h)
242 }
243
244 s.HostCustomQueries[fields[0]] = fields[1]
245 }
246 }
247
248 serveMux, err := web.NewMux(s)
249 if err != nil {
250 log.Fatal(err)
251 }
252
253 debugserver.AddHandlers(serveMux, *enablePprof)
254
255 if *enableIndexserverProxy {
256 socket := filepath.Join(*index, "indexserver.sock")
257 sglog.Scoped("server", "").Info("adding reverse proxy", sglog.String("socket", socket))
258 addProxyHandler(serveMux, socket)
259 }
260
261 handler := trace.Middleware(serveMux)
262
263 // Sourcegraph: We use environment variables to configure watchdog since
264 // they are more convenient than flags in containerized environments.
265 watchdogTick := 30 * time.Second
266 if v := os.Getenv("ZOEKT_WATCHDOG_TICK"); v != "" {
267 watchdogTick, _ = time.ParseDuration(v)
268 log.Printf("custom ZOEKT_WATCHDOG_TICK=%v", watchdogTick)
269 }
270
271 watchdogErrCount := 3
272 if v := os.Getenv("ZOEKT_WATCHDOG_ERRORS"); v != "" {
273 watchdogErrCount, _ = strconv.Atoi(v)
274 log.Printf("custom ZOEKT_WATCHDOG_ERRORS=%d", watchdogErrCount)
275 }
276
277 watchdogAddr := "http://" + *listen
278 if *sslCert != "" || *sslKey != "" {
279 watchdogAddr = "https://" + *listen
280 }
281 watchdogAddr += "/healthz"
282
283 if watchdogErrCount > 0 && watchdogTick > 0 {
284 go watchdog(watchdogTick, watchdogErrCount, watchdogAddr)
285 } else {
286 log.Println("watchdog disabled")
287 }
288
289 grpcServer := grpc.NewServer(
290 grpc.StreamInterceptor(otelgrpc.StreamServerInterceptor()),
291 grpc.UnaryInterceptor(otelgrpc.UnaryServerInterceptor()),
292 )
293 v1.RegisterWebserverServiceServer(grpcServer, zoektgrpc.NewServer(web.NewTraceAwareSearcher(s.Searcher)))
294
295 handler = multiplexGRPC(grpcServer, handler)
296
297 srv := &http.Server{
298 Addr: *listen,
299 Handler: handler,
300 }
301
302 go func() {
303 sglog.Scoped("server", "").Info("starting server", sglog.Stringp("address", listen))
304 var err error
305 if *sslCert != "" || *sslKey != "" {
306 err = srv.ListenAndServeTLS(*sslCert, *sslKey)
307 } else {
308 err = srv.ListenAndServe()
309 }
310
311 if err != http.ErrServerClosed {
312 // Fatal otherwise shutdownOnSignal will block
313 log.Fatalf("ListenAndServe: %v", err)
314 }
315 }()
316
317 if s.RPC {
318 // Our RPC system does not support shutdown and hijacks the underlying
319 // http connection. This means shutdown is ineffective and just waits 10s
320 // before calling close. Lets just quit faster in that case.
321 if err := closeOnSignal(srv); err != nil {
322 log.Fatalf("http.Server.Close: %v", err)
323 }
324 } else {
325 if err := shutdownOnSignal(srv); err != nil {
326 log.Fatalf("http.Server.Shutdown: %v", err)
327 }
328 }
329}
330
331// multiplexGRPC takes a gRPC server and a plain HTTP handler and multiplexes the
332// request handling. Any requests that declare themselves as gRPC requests are routed
333// to the gRPC server, all others are routed to the httpHandler.
334func multiplexGRPC(grpcServer *grpc.Server, httpHandler http.Handler) http.Handler {
335 newHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
336 if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
337 grpcServer.ServeHTTP(w, r)
338 } else {
339 httpHandler.ServeHTTP(w, r)
340 }
341 })
342
343 // Until we enable TLS, we need to fall back to the h2c protocol, which is
344 // basically HTTP2 without TLS. The standard library does not implement the
345 // h2s protocol, so this hijacks h2s requests and handles them correctly.
346 return h2c.NewHandler(newHandler, &http2.Server{})
347}
348
// addProxyHandler registers a handler on mux for the path prefix
// "/indexserver/". The prefix is stripped from matching requests, which are
// then forwarded by a reverse proxy over the Unix domain socket at socket.
func addProxyHandler(mux *http.ServeMux, socket string) {
	const prefix = "/indexserver/"

	// Every outgoing connection goes to the socket; the network and address
	// derived from the target URL are ignored.
	dialSocket := func(ctx context.Context, _, _ string) (net.Conn, error) {
		dialer := net.Dialer{}
		return dialer.DialContext(ctx, "unix", socket)
	}

	// The host is an arbitrary placeholder because dialSocket never looks at
	// the address it is asked to connect to.
	target := &url.URL{Scheme: "http", Host: "socket"}
	proxy := httputil.NewSingleHostReverseProxy(target)
	proxy.Transport = &http.Transport{DialContext: dialSocket}

	mux.Handle(prefix, http.StripPrefix(prefix, proxy))
}
366
367// shutdownSignalChan returns a channel which is listening for shutdown
368// signals from the operating system. maxReads is an upper bound on how many
369// times you will read the channel (used as buffer for signal.Notify).
370func shutdownSignalChan(maxReads int) <-chan os.Signal {
371 c := make(chan os.Signal, maxReads)
372 signal.Notify(c, os.Interrupt) // terminal C-c and goreman
373 signal.Notify(c, PLATFORM_SIGTERM) // Kubernetes
374 return c
375}
376
377// closeOnSignal will listen for SIGINT or SIGTERM and call srv.Close. This is
378// not a graceful shutdown, see shutdownOnSignal.
379func closeOnSignal(srv *http.Server) error {
380 c := shutdownSignalChan(1)
381 <-c
382
383 return srv.Close()
384}
385
386// shutdownOnSignal will listen for SIGINT or SIGTERM and call srv.Shutdown.
387// Note it doesn't call anything else for shutting down. Notably our RPC
388// framework doesn't allow us to drain connections, so when Shutdown we will
389// wait 10s before closing.
390//
391// Note: the call site for shutdownOnSignal should use closeOnSignal instead
392// if rpc mode is enabled due to the above limitation.
393func shutdownOnSignal(srv *http.Server) error {
394 c := shutdownSignalChan(2)
395 <-c
396
397 // If we receive another signal, immediate shutdown
398 ctx, cancel := context.WithCancel(context.Background())
399 defer cancel()
400 go func() {
401 select {
402 case <-ctx.Done():
403 case sig := <-c:
404 log.Printf("received another signal (%v), immediate shutdown", sig)
405 cancel()
406 }
407 }()
408
409 // Wait for 10s to drain ongoing requests. Kubernetes gives us 30s to
410 // shutdown, we have already used 15s waiting for our endpoint removal to
411 // propagate.
412 ctx, cancel2 := context.WithTimeout(ctx, 10*time.Second)
413 defer cancel2()
414
415 log.Printf("shutting down")
416 return srv.Shutdown(ctx)
417}
418
419func watchdogOnce(ctx context.Context, client *http.Client, addr string) error {
420 defer metricWatchdogTotal.Inc()
421
422 ctx, cancel := context.WithDeadline(ctx, time.Now().Add(30*time.Second))
423 defer cancel()
424
425 req, err := http.NewRequest("GET", addr, nil)
426 if err != nil {
427 return err
428 }
429
430 req = req.WithContext(ctx)
431
432 resp, err := client.Do(req)
433 if err != nil {
434 return err
435 }
436
437 if resp.StatusCode != http.StatusOK {
438 return fmt.Errorf("watchdog: status %v", resp.StatusCode)
439 }
440 return nil
441}
442
443func watchdog(dt time.Duration, maxErrCount int, addr string) {
444 tr := &http.Transport{
445 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
446 }
447 client := &http.Client{
448 Transport: tr,
449 }
450 tick := time.NewTicker(dt)
451
452 errCount := 0
453 for range tick.C {
454 err := watchdogOnce(context.Background(), client, addr)
455 if err != nil {
456 errCount++
457 metricWatchdogErrors.Set(float64(errCount))
458 metricWatchdogErrorsTotal.Inc()
459 if errCount >= maxErrCount {
460 log.Panicf("watchdog: %v", err)
461 } else {
462 log.Printf("watchdog: failed, will try %d more times: %v", maxErrCount-errCount, err)
463 }
464 } else if errCount > 0 {
465 errCount = 0
466 metricWatchdogErrors.Set(float64(errCount))
467 log.Printf("watchdog: success, resetting error count")
468 }
469 }
470}
471
472func diskUsage(path string) (*disk.UsageStat, error) {
473 duPath := path
474 if runtime.GOOS == "windows" {
475 duPath = filepath.VolumeName(duPath)
476 }
477 usage, err := disk.Usage(duPath)
478 if err != nil {
479 return nil, fmt.Errorf("diskUsage: %w", err)
480 }
481 return usage, err
482}
483
484func mustRegisterDiskMonitor(path string) {
485 prometheus.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
486 Name: "src_disk_space_available_bytes",
487 Help: "Amount of free space disk space.",
488 ConstLabels: prometheus.Labels{"path": path},
489 }, func() float64 {
490 // I know there is no error handling here, and I don't like it
491 // but there was no error handling in the previous version
492 // that used Statfs, either, so I'm assuming there's no need for it
493 usage, _ := diskUsage(path)
494 return float64(usage.Free)
495 }))
496
497 prometheus.MustRegister(prometheus.NewGaugeFunc(prometheus.GaugeOpts{
498 Name: "src_disk_space_total_bytes",
499 Help: "Amount of total disk space.",
500 ConstLabels: prometheus.Labels{"path": path},
501 }, func() float64 {
502 // I know there is no error handling here, and I don't like it
503 // but there was no error handling in the previous version
504 // that used Statfs, either, so I'm assuming there's no need for it
505 usage, _ := diskUsage(path)
506 return float64(usage.Total)
507 }))
508}
509
510type loggedSearcher struct {
511 zoekt.Streamer
512 Logger sglog.Logger
513}
514
515func (s *loggedSearcher) Search(
516 ctx context.Context,
517 q query.Q,
518 opts *zoekt.SearchOptions,
519) (sr *zoekt.SearchResult, err error) {
520 defer func() {
521 var stats *zoekt.Stats
522 if sr != nil {
523 stats = &sr.Stats
524 }
525 s.log(ctx, q, opts, stats, err)
526 }()
527
528 metricSearchRequestsTotal.Inc()
529 return s.Streamer.Search(ctx, q, opts)
530}
531
532func (s *loggedSearcher) StreamSearch(
533 ctx context.Context,
534 q query.Q,
535 opts *zoekt.SearchOptions,
536 sender zoekt.Sender,
537) error {
538 var (
539 stats zoekt.Stats
540 )
541
542 metricSearchRequestsTotal.Inc()
543 err := s.Streamer.StreamSearch(ctx, q, opts, stream.SenderFunc(func(event *zoekt.SearchResult) {
544 stats.Add(event.Stats)
545 sender.Send(event)
546 }))
547
548 s.log(ctx, q, opts, &stats, err)
549
550 return err
551}
552
553func (s *loggedSearcher) log(ctx context.Context, q query.Q, opts *zoekt.SearchOptions, st *zoekt.Stats, err error) {
554 logger := s.Logger.
555 WithTrace(traceContext(ctx)).
556 With(
557 sglog.String("query", q.String()),
558 sglog.Bool("opts.EstimateDocCount", opts.EstimateDocCount),
559 sglog.Bool("opts.Whole", opts.Whole),
560 sglog.Int("opts.ShardMaxMatchCount", opts.ShardMaxMatchCount),
561 sglog.Int("opts.TotalMaxMatchCount", opts.TotalMaxMatchCount),
562 sglog.Duration("opts.MaxWallTime", opts.MaxWallTime),
563 sglog.Int("opts.MaxDocDisplayCount", opts.MaxDocDisplayCount),
564 )
565
566 if err != nil {
567 switch {
568 case errors.Is(err, context.Canceled):
569 logger.Warn("search canceled", sglog.Error(err))
570 case errors.Is(err, context.DeadlineExceeded):
571 logger.Warn("search timeout", sglog.Error(err))
572 default:
573 logger.Error("search failed", sglog.Error(err))
574 }
575 return
576 }
577
578 if st == nil {
579 return
580 }
581
582 logger.Debug("search",
583 sglog.Int64("stat.ContentBytesLoaded", st.ContentBytesLoaded),
584 sglog.Int64("stat.IndexBytesLoaded", st.IndexBytesLoaded),
585 sglog.Int("stat.Crashes", st.Crashes),
586 sglog.Duration("stat.Duration", st.Duration),
587 sglog.Int("stat.FileCount", st.FileCount),
588 sglog.Int("stat.ShardFilesConsidered", st.ShardFilesConsidered),
589 sglog.Int("stat.FilesConsidered", st.FilesConsidered),
590 sglog.Int("stat.FilesLoaded", st.FilesLoaded),
591 sglog.Int("stat.FilesSkipped", st.FilesSkipped),
592 sglog.Int("stat.ShardsScanned", st.ShardsScanned),
593 sglog.Int("stat.ShardsSkipped", st.ShardsSkipped),
594 sglog.Int("stat.ShardsSkippedFilter", st.ShardsSkippedFilter),
595 sglog.Int("stat.MatchCount", st.MatchCount),
596 sglog.Int("stat.NgramMatches", st.NgramMatches),
597 sglog.Int("stat.NgramLookups", st.NgramLookups),
598 sglog.Duration("stat.Wait", st.Wait),
599 sglog.Int("stat.RegexpsConsidered", st.RegexpsConsidered),
600 sglog.String("stat.FlushReason", st.FlushReason.String()),
601 )
602}
603
604func traceContext(ctx context.Context) sglog.TraceContext {
605 otSpan := opentracing.SpanFromContext(ctx)
606 if otSpan != nil {
607 if jaegerSpan, ok := otSpan.Context().(jaeger.SpanContext); ok {
608 return sglog.TraceContext{
609 TraceID: jaegerSpan.TraceID().String(),
610 SpanID: jaegerSpan.SpanID().String(),
611 }
612 }
613 }
614
615 if otelSpan := oteltrace.SpanFromContext(ctx).SpanContext(); otelSpan.IsValid() {
616 return sglog.TraceContext{
617 TraceID: otelSpan.TraceID().String(),
618 SpanID: otelSpan.SpanID().String(),
619 }
620 }
621
622 return sglog.TraceContext{}
623}
624
625var (
626 metricWatchdogErrors = promauto.NewGauge(prometheus.GaugeOpts{
627 Name: "zoekt_webserver_watchdog_errors",
628 Help: "The current error count for zoekt watchdog.",
629 })
630 metricWatchdogTotal = promauto.NewCounter(prometheus.CounterOpts{
631 Name: "zoekt_webserver_watchdog_total",
632 Help: "The total number of requests done by zoekt watchdog.",
633 })
634 metricWatchdogErrorsTotal = promauto.NewCounter(prometheus.CounterOpts{
635 Name: "zoekt_webserver_watchdog_errors_total",
636 Help: "The total number of errors from zoekt watchdog.",
637 })
638 metricSearchRequestsTotal = promauto.NewCounter(prometheus.CounterOpts{
639 Name: "zoekt_search_requests_total",
640 Help: "The total number of search requests that zoekt received",
641 })
642)