summaryrefslogtreecommitdiff
path: root/lib/profile/cobra.go
blob: 30940153a9fc55ade098939dfabc1d53d9fdb439 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Copyright (C) 2023  Luke Shumaker <lukeshu@lukeshu.com>
//
// SPDX-License-Identifier: GPL-2.0-or-later

package profile

import (
	"io"
	"os"

	"github.com/datawire/dlib/derror"
	"github.com/spf13/cobra"
	"github.com/spf13/pflag"
)

type flagSet struct {
	shutdown []StopFunc
}

func (fs *flagSet) Stop() error {
	var errs derror.MultiError
	for _, fn := range fs.shutdown {
		if err := fn(); err != nil {
			errs = append(errs, err)
		}
	}
	if len(errs) > 0 {
		return errs
	}
	return nil
}

type flagValue struct {
	parent *flagSet
	start  startFunc
	curVal string
}

var _ pflag.Value = (*flagValue)(nil)

// String implements pflag.Value.
func (fv *flagValue) String() string { return fv.curVal }

// Set implements pflag.Value.
func (fv *flagValue) Set(filename string) error {
	if filename == "" {
		return nil
	}
	w, err := os.Create(filename)
	if err != nil {
		return err
	}
	shutdown, err := fv.start(w)
	if err != nil {
		return err
	}
	fv.curVal = filename
	fv.parent.shutdown = append(fv.parent.shutdown, func() error {
		err1 := shutdown()
		err2 := w.Close()
		if err1 != nil {
			return err1
		}
		return err2
	})
	return nil
}

// Type implements pflag.Value.
func (fv *flagValue) Type() string { return "filename" }

func pStart(name string) startFunc {
	return func(w io.Writer) (StopFunc, error) {
		return Profile(w, name)
	}
}

// AddProfileFlags adds flags to a pflag.FlagSet to write any (or all)
// of the standard profiles to a file, and returns a "stop" function
// to be called at program shutdown.
func AddProfileFlags(flags *pflag.FlagSet, prefix string) StopFunc {
	var root flagSet

	flags.Var(&flagValue{parent: &root, start: CPU}, prefix+"cpu", "Write a CPU profile to the file `cpu.pprof`")
	_ = cobra.MarkFlagFilename(flags, prefix+"cpu")

	flags.Var(&flagValue{parent: &root, start: Trace}, prefix+"trace", "Write a trace (https://pkg.go.dev/runtime/trace) to the file `trace.out`")
	_ = cobra.MarkFlagFilename(flags, prefix+"trace")

	flags.Var(&flagValue{parent: &root, start: pStart("goroutine")}, prefix+"goroutine", "Write a goroutine profile to the file `goroutine.pprof`")
	_ = cobra.MarkFlagFilename(flags, prefix+"goroutine")

	flags.Var(&flagValue{parent: &root, start: pStart("threadcreate")}, prefix+"threadcreate", "Write a threadcreate profile to the file `threadcreate.pprof`")
	_ = cobra.MarkFlagFilename(flags, prefix+"threadcreate")

	flags.Var(&flagValue{parent: &root, start: pStart("heap")}, prefix+"heap", "Write a heap profile to the file `heap.pprof`")
	_ = cobra.MarkFlagFilename(flags, prefix+"heap")

	flags.Var(&flagValue{parent: &root, start: pStart("allocs")}, prefix+"allocs", "Write an allocs profile to the file `allocs.pprof`")
	_ = cobra.MarkFlagFilename(flags, prefix+"allocs")

	flags.Var(&flagValue{parent: &root, start: pStart("block")}, prefix+"block", "Write a block profile to the file `block.pprof`")
	_ = cobra.MarkFlagFilename(flags, prefix+"block")

	flags.Var(&flagValue{parent: &root, start: pStart("mutex")}, prefix+"mutex", "Write a mutex profile to the file `mutex.pprof`")
	_ = cobra.MarkFlagFilename(flags, prefix+"mutex")

	return root.Stop
}