package testo
import (
"reflect"
"runtime"
"slices"
"strings"
"sync"
"github.com/ozontech/testo/internal/reflectutil"
"github.com/ozontech/testo/testoplugin"
)
// testID identifies a test function by its normalized runtime function name
// (see getID).
type testID string
// Registry of static test annotations.
var (
// globalAnnotationsMu guards globalAnnotations.
globalAnnotationsMu sync.RWMutex
// globalAnnotations maps a test ID to the options accumulated for it
// through [For] and [ForEach].
globalAnnotations = make(map[testID][]testoplugin.Option)
)
// ForEach registers static options (annotations) that apply to every case
// of the given parametrized test.
//
// It is the parametrized-test counterpart of [For].
//
//	testo.ForEach(MySuite.TestFoo, someplugin.WithThis(...), otherplugin.WithThat(...))
//	testo.ForEach((*Suite).TestFoo, myplugin.WithRetry())
//
// The returned empty struct only exists so that the call can be used
// in a "var _ =" declaration.
func ForEach[Suite suite[T], T CommonT, P any](
	test func(Suite, T, P),
	options ...testoplugin.Option,
) struct{} {
	id := getID[Suite](reflect.ValueOf(test))
	annotate(id, options...)

	return struct{}{}
}
// For registers static options (annotations) for the given test.
//
// Annotations are a slice of options attached to a test.
// They are available to plugins before the actual test runs,
// which enhances test planning features.
//
// Calling For several times for the same test appends the options.
//
//	testo.For(MySuite.TestFoo, someplugin.WithThis(...), otherplugin.WithThat(...))
//	testo.For((*Suite).TestFoo, myplugin.WithRetry())
//
// NOTE: it returns an empty struct{} value to enable the following usage:
//
//	var _ = testo.For(MySuite.TestFoo, someplugin.WithSomeOption(true))
//
//	func (MySuite) TestFoo(t T) { ... }
//
// That "var _ =" construction would not be possible otherwise.
// It is slightly less verbose than using an init function:
//
//	func init() {
//		testo.For(MySuite.TestFoo, someplugin.WithSomeOption(true))
//	}
func For[Suite suite[T], T CommonT](
	test func(Suite, T),
	options ...testoplugin.Option,
) struct{} {
	id := getID[Suite](reflect.ValueOf(test))
	annotate(id, options...)

	return struct{}{}
}
// annotate appends options to the annotations already registered for id.
// It takes the write lock of the global registry for the whole update.
func annotate(id testID, options ...testoplugin.Option) {
	globalAnnotationsMu.Lock()
	defer globalAnnotationsMu.Unlock()

	registered := globalAnnotations[id]
	globalAnnotations[id] = append(registered, options...)
}
// annotationsFor returns a copy of the options registered for id,
// so callers cannot modify the shared registry through the result.
func annotationsFor(id testID) []testoplugin.Option {
	globalAnnotationsMu.RLock()
	defer globalAnnotationsMu.RUnlock()

	registered := globalAnnotations[id]

	return slices.Clone(registered)
}
// getID derives the test ID from the runtime name of the function held by test.
//
// Every "(*SuiteName)" occurrence in that name is replaced with "SuiteName",
// so pointer-receiver and value-receiver spellings yield the same ID.
func getID[Suite any](test reflect.Value) testID {
	suiteName := reflectutil.NameOf[Suite]()
	pointerReceiver := "(*" + suiteName + ")"

	funcName := runtime.FuncForPC(test.Pointer()).Name()

	return testID(strings.ReplaceAll(funcName, pointerReceiver, suiteName))
}
package testo
import (
"fmt"
"maps"
"os"
"reflect"
"slices"
"strings"
"testing"
"unicode"
"unicode/utf8"
"github.com/ozontech/testo/internal/pragma"
"github.com/ozontech/testo/internal/testnamer"
"github.com/ozontech/testo/testoplugin"
"github.com/ozontech/testo/testoreflect"
)
type (
// suiteTest is a single runnable test of a suite.
suiteTest[Suite suite[T], T CommonT] struct {
// Name is the name of the suite method implementing the test.
Name string
// Info is the reflection info describing the test.
Info testoreflect.TestInfo
// Run invokes the test with the given suite value and T.
Run func(Suite, T)
}
// suiteCase describes a "CasesXXX" suite method providing parameter values.
suiteCase[Suite suite[T], T CommonT] struct {
// Provides is the element type of the slice returned by the method.
Provides reflect.Type
// Func calls the method on the given suite and returns the slice elements.
Func func(Suite) []reflect.Value
}
)
// Compile-time check that plannedSuiteTest implements testoplugin.PlannedTest.
var _ testoplugin.PlannedTest = (*plannedSuiteTest[Suite[*T], *T])(nil)
// plannedSuiteTest wraps an annotated suite test so that it can be
// exposed as a testoplugin.PlannedTest.
type plannedSuiteTest[Suite suite[T], T CommonT] struct {
inner annotatedSuiteTest[Suite, T]
}
// TestoInternal is an empty marker method taking pragma.DoNotImplement.
func (plannedSuiteTest[Suite, T]) TestoInternal(pragma.DoNotImplement) {}
// Info returns the reflection info of the wrapped test.
func (t plannedSuiteTest[Suite, T]) Info() testoreflect.TestInfo {
return t.inner.Info
}
// Annotations returns a copy of the options attached to the wrapped test.
func (t plannedSuiteTest[Suite, T]) Annotations() []testoplugin.Option {
return slices.Clone(t.inner.Options)
}
// isTest reports whether name is a valid test name for the given prefix
// (the same rule is used for other method kinds, e.g. "Cases").
//
// The name must start with prefix, and the first rune following the prefix,
// if there is one, must not be a lowercase letter.
//
//	TestFoo    => true
//	Test       => true
//	TestfooBar => false
func isTest(name, prefix string) bool {
	rest, hasPrefix := strings.CutPrefix(name, prefix)
	if !hasPrefix {
		return false
	}

	// A bare prefix ("Test") is a valid name on its own.
	if rest == "" {
		return true
	}

	first, _ := utf8.DecodeRuneInString(rest)

	return !unicode.IsLower(first)
}
// suiteCasesOf collects all "CasesXXX" methods of the Suite type.
//
// The result is keyed by the method name without the "Cases" prefix.
// A method named exactly "Cases" is ignored.
//
// Each such method must take no arguments besides the receiver and return
// exactly one value of a slice type; otherwise tb.Fatalf is called.
func suiteCasesOf[Suite suite[T], T CommonT](tb testing.TB) map[string]suiteCase[Suite, T] {
	tb.Helper()

	const prefix = "Cases"

	suiteType := reflect.TypeFor[Suite]()

	cases := make(map[string]suiteCase[Suite, T])

	for i := range suiteType.NumMethod() {
		method := suiteType.Method(i)

		if !isTest(method.Name, prefix) {
			continue
		}

		name := strings.TrimPrefix(method.Name, prefix)
		if name == "" {
			continue
		}

		// The receiver counts as the first (and only allowed) input.
		isValidIn := method.Type.NumIn() == 1
		isValidOut := method.Type.NumOut() == 1 && method.Type.Out(0).Kind() == reflect.Slice

		if !isValidIn || !isValidOut {
			tb.Fatalf(
				"testo: wrong signature for %[1]s.%[2]s, must be: func (%[1]s) %[2]s() []...",
				suiteType, method.Name,
			)
		}

		cases[name] = suiteCase[Suite, T]{
			Provides: method.Type.Out(0).Elem(),
			Func: func(s Suite) []reflect.Value {
				var suite reflect.Value

				// Pass a pointer when the method wants a pointer receiver
				// but the suite value itself is not a pointer.
				if method.Type.In(0).Kind() == reflect.Pointer &&
					reflect.TypeOf(s).Kind() != reflect.Pointer {
					suite = reflect.ValueOf(&s)
				} else {
					suite = reflect.ValueOf(s)
				}

				slice := method.Func.Call([]reflect.Value{suite})[0]

				values := make([]reflect.Value, 0, slice.Len())

				for i := range slice.Len() {
					values = append(values, slice.Index(i))
				}

				return values
			},
		}
	}

	return cases
}
// suiteTests contains all the suite tests.
//
// While regular tests are ready to be run,
// parametrized tests are tricky.
//
// We can't know how many permutations (hence number of tests)
// they will have until we receive all values for each case by calling CasesXXX funcs.
// However, we can't do that before running the BeforeAll hooks - cases funcs may
// depend on it being run first.
//
// But we should not run any hooks until we are sure that tests are correct
// and no error should be raised.
//
// That's why we statically analyze parametrized tests signatures,
// but delay the actual collection for later.
type suiteTests[Suite suite[T], T CommonT] struct {
// Regular holds tests which are ready to run as is.
Regular []suiteTest[Suite, T]
// Parametrized holds tests whose concrete cases are produced later, by Collect.
Parametrized []suiteTestParametrized[Suite, T]
}
// annotatedSuiteTest is a suite test together with the options
// (annotations) registered for it.
type annotatedSuiteTest[Suite suite[T], T CommonT] struct {
suiteTest[Suite, T]
// Options are the annotations looked up for this test via annotationsFor.
Options []testoplugin.Option
}
// Collect returns every suite test: regular tests first (each paired with
// its registered annotations), followed by the cases of each parametrized test.
//
// The suite instance is needed only to obtain parameter cases
// (CasesXXX funcs), not to invoke the actual tests.
func (st suiteTests[Suite, T]) Collect(s Suite) []annotatedSuiteTest[Suite, T] {
	collected := make([]annotatedSuiteTest[Suite, T], 0, len(st.Regular))

	for i := range st.Regular {
		regular := st.Regular[i]
		id := getID[Suite](reflect.ValueOf(regular.Run))

		collected = append(collected, annotatedSuiteTest[Suite, T]{
			suiteTest: regular,
			Options:   annotationsFor(id),
		})
	}

	for i := range st.Parametrized {
		collected = append(collected, st.Parametrized[i].Tests(s)...)
	}

	return collected
}
// testsCollector discovers the tests of a suite type by reflection.
type testsCollector[Suite suite[T], T CommonT] struct {
// CallerName is passed as the parent name when naming tests.
CallerName string
// TestNamer produces the names of collected tests.
TestNamer *testnamer.Namer
}
// testName returns the name for the test with the given base name,
// using CallerName as its parent.
func (tc *testsCollector[Suite, T]) testName(base string) string {
return tc.TestNamer.Name(tc.CallerName, base)
}
// Collect inspects the methods of Suite and returns its regular and
// parametrized tests.
//
// Only methods accepted by isTest with the "Test" prefix are considered.
// Each of them must return nothing and take T as the first argument after
// the receiver, optionally followed by a single struct of parameters;
// otherwise tb.Fatalf is called.
//
// For parametrized tests every field of the params struct must have a matching
// "Cases<FieldName>" method providing values assignable to the field type.
//
// Tests whose method name does not match flagMethod are validated but not collected.
//
//nolint:cyclop,funlen,gocognit // splitting it would make it even more complex
func (tc *testsCollector[Suite, T]) Collect(
tb testing.TB,
) suiteTests[Suite, T] {
tb.Helper()
cases := suiteCasesOf[Suite](tb)
suite := reflect.TypeFor[Suite]()
var tests suiteTests[Suite, T]
for i := range suite.NumMethod() {
method := suite.Method(i)
if !isTest(method.Name, "Test") {
continue
}
raiseWrongSignatureError := func() {
tb.Helper()
//nolint:lll // it's a long message
tb.Fatalf(
"testo: wrong signature for (%[1]s).%[2]s, must be: func (%[1]s).%[2]s(%[3]s) or func (%[1]s).%[2]s(%[3]s, struct{...})",
suite,
method.Name,
reflect.TypeFor[T](),
)
}
if method.Type.NumOut() != 0 {
raiseWrongSignatureError()
}
// The receiver is the first input, so a valid test has at least two.
if method.Type.NumIn() < 2 {
raiseWrongSignatureError()
}
if method.Type.In(1) != reflect.TypeFor[T]() {
raiseWrongSignatureError()
}
switch method.Type.NumIn() {
default:
raiseWrongSignatureError()
case 2: // regular test - (Suite, T)
if !flagMethod.MatchString(method.Name) {
continue
}
tests.Regular = append(tests.Regular, suiteTest[Suite, T]{
Name: method.Name,
Info: testoreflect.RegularTestInfo{
Name: tc.testName(method.Name),
RawBaseName: method.Name,
Level: 1,
FuncPC: method.Func.Pointer(),
},
Run: method.Func.Interface().(func(Suite, T)),
})
case 3: // parametrized test - (Suite, T, Params)
param := method.Type.In(2)
if param.Kind() != reflect.Struct {
raiseWrongSignatureError()
}
// Match every field of the params struct with the cases method of the same name.
requiredCases := make(map[string]suiteCase[Suite, T])
for i := range param.NumField() {
field := param.Field(i)
c, ok := cases[field.Name]
if !ok {
tb.Fatalf(
"testo: wrong param signature for (%[1]s).%[2]s: missing (%[1]s).Cases%[3]s() []%s for param %[3]q",
reflect.TypeFor[Suite](),
method.Name,
field.Name,
field.Type,
)
}
if !c.Provides.AssignableTo(field.Type) {
//nolint:lll // splitting string into multiple lines is worse
tb.Fatalf(
"testo: wrong param signature for (%[1]s).%[2]s: (%[1]s).Cases%[3]s provides %[4]s values, not assignable to param %[3]q of type %[5]s",
reflect.TypeFor[Suite](),
method.Name,
field.Name,
c.Provides,
field.Type,
)
}
requiredCases[field.Name] = c
}
// The filter is applied only after validation, so that wrong signatures
// are reported even for tests which are not selected.
if !flagMethod.MatchString(method.Name) {
continue
}
tests.Parametrized = append(
tests.Parametrized,
tc.newParametrizedTest(method, requiredCases),
)
}
}
return tests
}
// suiteTestParametrized is a parametrized test whose concrete cases
// are produced on demand.
type suiteTestParametrized[Suite suite[T], T CommonT] struct {
// Tests returns the concrete tests for the given suite instance.
Tests func(Suite) []annotatedSuiteTest[Suite, T]
}
// newParametrizedTest builds a lazily expanded parametrized test for method.
//
// The returned Tests func calls each required cases func on the suite,
// builds all permutations of the provided values and returns one test per
// permutation. If any cases func provides zero values, a warning is written
// to stderr and no tests are returned.
//
//nolint:funlen // no way to reduce length without losing readability
func (tc *testsCollector[Suite, T]) newParametrizedTest(
method reflect.Method,
cases map[string]suiteCase[Suite, T],
) suiteTestParametrized[Suite, T] {
return suiteTestParametrized[Suite, T]{
Tests: func(s Suite) []annotatedSuiteTest[Suite, T] {
casesValues := make(map[string][]reflect.Value, len(cases))
for caseName, c := range cases {
values := c.Func(s)
if len(values) == 0 {
structName := method.Type.In(0).String()
fmt.Fprintf(
os.Stderr,
"testo: warning: (%[1]s).Cases%[2]s provides zero values, (%[1]s).%[3]s will not run\n",
structName,
caseName,
method.Name,
)
return nil
}
casesValues[caseName] = values
}
permutations := casesPermutations(casesValues)
tests := make([]annotatedSuiteTest[Suite, T], 0, len(permutations))
for i, params := range permutations {
// Build the params struct value passed to the test as its third argument.
paramValue := reflect.New(method.Type.In(2)).Elem()
caseParams := make(map[string]any, len(params))
for paramName, value := range params {
paramValue.FieldByName(paramName).Set(value)
caseParams[paramName] = value.Interface()
}
tests = append(tests, annotatedSuiteTest[Suite, T]{
suiteTest: suiteTest[Suite, T]{
Name: method.Name,
Info: testoreflect.ParametrizedTestInfo{
Name: tc.testName(method.Name),
BaseName: method.Name,
Index: i,
CasesCount: len(permutations),
Params: caseParams,
FuncPC: method.Func.Pointer(),
},
Run: func(s Suite, t T) {
method.Func.Call([]reflect.Value{
reflect.ValueOf(s),
reflect.ValueOf(t),
paramValue,
})
},
},
Options: annotationsFor(getID[Suite](method.Func)),
})
}
return tests
},
}
}
// casesPermutations returns all combinations of the given case values,
// one value per key, in a deterministic order.
//
// Keys are processed in sorted order; values of the first key change
// slowest and values of the last key change fastest.
// An empty input yields a single empty combination,
// and a key without values yields no combinations at all.
func casesPermutations[V any](v map[string][]V) []map[string]V {
	names := make([]string, 0, len(v))
	for name := range v {
		names = append(names, name)
	}

	// Sort keys for deterministic output.
	slices.Sort(names)

	// Start from a single empty combination and extend it key by key.
	combos := []map[string]V{{}}

	for _, name := range names {
		candidates := v[name]
		extended := make([]map[string]V, 0, len(combos)*len(candidates))

		for _, partial := range combos {
			for _, candidate := range candidates {
				next := maps.Clone(partial)
				next[name] = candidate
				extended = append(extended, next)
			}
		}

		combos = extended
	}

	return combos
}
package testo
import (
"fmt"
"reflect"
"github.com/ozontech/testo/internal/reflectutil"
"github.com/ozontech/testo/internal/stack"
"github.com/ozontech/testo/testoplugin"
)
// construct creates a new T bound to the given testing t.
//
// parent, when not nil, is the T this one derives from: its unwrapped value
// becomes the parent of the new one, its concrete type is used as the type
// to construct, and its set of plugin types is reused.
// Without a parent, plugin types are discovered from the type T itself.
//
// fill, when not nil, is called with the freshly created inner *testoT
// before anything else is set up. options become the level options of the new T.
//
// It panics if the constructed type recursively contains itself.
//
//nolint:cyclop,funlen // this is the core of the whole framework
func construct[T CommonT](
	t TestingT,
	parent *T,
	fill func(t *testoT),
	options ...testoplugin.Option,
) T {
	t.Helper()

	seed := testoT{
		common:       t,
		testingT:     t,
		levelOptions: options,
	}

	if parent != nil {
		seed.parent = (*parent).unwrap()
	}

	if fill != nil {
		fill(&seed)
	}

	// Passed T type may be an interface.
	// In that case, we can not initialize it, since
	// we don't know its underlying concrete type.
	//
	// However, parent is always a concrete value,
	// so, if present, we extract that type from it.
	realType := reflect.TypeFor[T]()
	if parent != nil {
		realType = reflect.ValueOf(*parent).Type()
	}

	// special case when T is *testo.T
	if realType == reflect.TypeFor[*testoT]() {
		return any(&seed).(T)
	}

	value := reflect.New(realType)

	if !reflectutil.Fill(value) {
		panic(fmt.Sprintf(
			"testo: infinite type recursion detected for %s inside %s",
			reflectutil.FindRecursiveType(realType),
			realType,
		))
	}

	if parent == nil {
		pluginTypes := make(map[reflect.Type]struct{})
		collectPlugins(realType, pluginTypes)

		// The constructed type itself is not treated as its own plugin.
		delete(pluginTypes, realType)

		seed.plugins = instantiatePlugins(&seed, pluginTypes)
	} else {
		// Children get fresh instances of the same plugin types as the parent.
		seed.plugins = instantiatePlugins(&seed, (*parent).unwrap().plugins)
	}

	specsStack := stack.New[typedPlugin]()

	for pluginType, pluginValue := range seed.plugins {
		specsStack.Push(typedPlugin{
			Plugin: pluginValue,
			Type:   pluginType,
		})

		setPlugins(reflect.ValueOf(pluginValue), seed.plugins, &specsStack)
	}

	setPlugins(value, seed.plugins, &specsStack)

	specs := make(map[reflect.Type]testoplugin.Spec, len(seed.plugins))

	for {
		p, ok := specsStack.Pop()
		if !ok {
			break
		}

		// Each plugin type contributes its spec only once.
		if _, ok := specs[p.Type]; ok {
			continue
		}

		var parentPlugin testoplugin.Plugin

		if parent != nil {
			parentPlugin = (*parent).unwrap().plugins[p.Type]
		} else {
			// No parent: pass the zero value of the plugin type.
			parentPlugin = reflect.New(p.Type).Elem().Interface().(testoplugin.Plugin)
		}

		specs[p.Type] = p.Plugin.Plugin(parentPlugin, seed.options()...)
	}

	seed.spec = mergeSpecs(t, mapValues(specs)...)

	return value.Elem().Interface().(T)
}

// instantiatePlugins creates one plugin instance for each type in types.
//
// The *testoT plugin type is represented by seed itself; every other type
// gets a new value with its pointer fields filled by reflectutil.Fill.
func instantiatePlugins[V any](
	seed *testoT,
	types map[reflect.Type]V,
) map[reflect.Type]testoplugin.Plugin {
	plugins := make(map[reflect.Type]testoplugin.Plugin, len(types))

	for pluginType := range types {
		if pluginType == reflect.TypeFor[*testoT]() {
			plugins[pluginType] = seed

			continue
		}

		v := reflect.New(pluginType.Elem())
		reflectutil.Fill(v)

		plugins[pluginType] = v.Interface().(testoplugin.Plugin)
	}

	return plugins
}
// typedPlugin pairs a plugin instance with the reflect.Type it is registered under.
type typedPlugin struct {
testoplugin.Plugin
// Type is the key of the plugin in the plugins map.
Type reflect.Type
}
// setPlugins walks v and replaces each value whose type is a key of plugins
// with the corresponding shared plugin instance. Every instance assigned this
// way is also pushed onto specs.
//
// v is expected to be a pointer: its element type is looked up in plugins,
// and then the exported fields of the underlying struct are visited recursively.
//
// It panics if a value to replace is invalid or cannot be set,
// or if an exported field of a visited struct is not a pointer.
func setPlugins(
v reflect.Value,
plugins map[reflect.Type]testoplugin.Plugin,
specs *stack.Stack[typedPlugin],
) {
if plugin, ok := plugins[v.Type().Elem()]; ok {
elem := v.Elem()
if !elem.IsValid() {
panic(fmt.Sprintf("testo: invalid elem for %s", v.Type()))
}
if !elem.CanSet() {
// TODO(metafates): add path to the field so that it is clear where error happens
panic(fmt.Sprintf("testo: can't set value for %s", v.Type()))
}
elem.Set(reflect.ValueOf(plugin))
specs.Push(typedPlugin{
Plugin: elem.Interface().(testoplugin.Plugin),
Type: elem.Type(),
})
}
// Unwrap all pointer levels to reach the underlying value.
v = reflectutil.Elem(v)
if v.Kind() != reflect.Struct {
return
}
// special case - we do not go deeper than that.
if v.Type() == reflect.TypeFor[T]() {
return
}
for i := range v.NumField() {
field := v.Field(i)
if !v.Type().Field(i).IsExported() {
continue
}
if field.Kind() != reflect.Pointer {
panic(
fmt.Sprintf(
"testo: all exported fields in T must be pointers, got: %s",
field.Type(),
),
)
}
// Pass the address of the field so that it can be replaced in place.
setPlugins(field.Addr(), plugins, specs)
}
}
// pluginInterfaceType is the reflect.Type of the testoplugin.Plugin interface.
var pluginInterfaceType = reflect.TypeFor[testoplugin.Plugin]()

// collectPlugins records in plugins every type implementing
// testoplugin.Plugin which is reachable from typ: typ itself and,
// recursively, the types of exported fields of its underlying struct.
func collectPlugins(typ reflect.Type, plugins map[reflect.Type]struct{}) {
	if typ.Implements(pluginInterfaceType) {
		plugins[typ] = struct{}{}
	}

	underlying := reflectutil.Elem(typ)
	if underlying.Kind() != reflect.Struct {
		return
	}

	for i := range underlying.NumField() {
		if f := underlying.Field(i); f.IsExported() {
			collectPlugins(f.Type, plugins)
		}
	}
}
// mapValues returns the values of m as a new slice, in map iteration
// (i.e. unspecified) order. The result is never nil.
func mapValues[K comparable, V any](m map[K]V) []V {
	out := make([]V, len(m))

	next := 0
	for _, value := range m {
		out[next] = value
		next++
	}

	return out
}
package testo
import (
"flag"
"regexp"
// so that cache flags are always available.
_ "github.com/ozontech/testo/testocache"
)
// flagMethod is the regular expression matched against suite method names
// when selecting tests. It defaults to the empty pattern.
var flagMethod = flagRegexp{Regexp: regexp.MustCompile("")}
// init registers flagMethod as the "testo.m" command-line flag.
func init() {
flag.Var(&flagMethod, "testo.m", "regular expression to select tests of the testo suite to run")
}
// Compile-time check that *flagRegexp can be used as a flag value.
var _ flag.Value = (*flagRegexp)(nil)

// flagRegexp is a regular expression which can be set from a command-line flag.
type flagRegexp struct{ *regexp.Regexp }

// Set compiles s and, on success, makes it the current expression.
// On a compilation error the previous expression is kept and the error is returned.
func (f *flagRegexp) Set(s string) error {
	compiled, err := regexp.Compile(s)
	if err == nil {
		f.Regexp = compiled
	}

	return err
}

// String returns the source text of the current expression,
// or an empty string if none is set.
func (f *flagRegexp) String() string {
	if re := f.Regexp; re != nil {
		return re.String()
	}

	return ""
}
package reflectutil
import (
"reflect"
)
// canElem is implemented by reflection types which report their kind
// and can be dereferenced, such as reflect.Type and reflect.Value.
type canElem[Self any] interface {
	Kind() reflect.Kind
	Elem() Self
}

// Elem dereferences v until it is no longer a pointer.
//
// Nested pointers are supported as well: given "****value" it returns "value".
//
// Non-pointer values are returned as is.
func Elem[T canElem[T]](v T) T {
	for {
		if v.Kind() != reflect.Pointer {
			return v
		}

		v = v.Elem()
	}
}
// NameOf returns the name of type T.
// If T is a pointer type, the name of the type it ultimately points to is returned.
func NameOf[T any]() string {
	return Elem(reflect.TypeFor[T]()).Name()
}
// New a new zero value of T.
//
// As a special case for pointers it will
// return pointer to the zero value of T (not nil).
func New[T any]() T {
t := reflect.TypeFor[T]()
var zero T
if t.Kind() == reflect.Pointer {
elem := reflect.ValueOf(&zero).Elem()
elem.Set(reflect.New(t.Elem()))
}
return zero
}
// Filled returns a new value of type hint with all settable pointer fields
// recursively set to non-nil zero values.
// That is, if the type is a struct containing a field *int, the field is set to &0.
// The same logic applies to nested structs.
//
// The boolean result is false if type recursion was detected while filling.
func Filled(hint reflect.Type) (reflect.Value, bool) {
	ptr := reflect.New(hint)

	if !Fill(ptr) {
		return ptr.Elem(), false
	}

	return ptr.Elem(), true
}
// Fill recursively sets nil pointers reachable from v to non-nil zero values.
// It returns false if type recursion was detected.
func Fill(v reflect.Value) bool {
	visited := make(map[reflect.Type]bool)

	return fill(v, visited)
}
// FindRecursiveType returns the struct type which is reached again
// while walking the exported fields of t, or nil if there is none.
func FindRecursiveType(t reflect.Type) reflect.Type {
	visited := make(map[reflect.Type]bool)

	return findRecursiveType(t, visited)
}
func findRecursiveType(t reflect.Type, visited map[reflect.Type]bool) reflect.Type {
switch t.Kind() {
case reflect.Pointer:
return findRecursiveType(t.Elem(), visited)
case reflect.Struct:
if visited[t] {
return t
}
visited[t] = true
for i := range t.NumField() {
field := t.Field(i)
if !field.IsExported() {
continue
}
if found := findRecursiveType(field.Type, visited); found != nil {
return found
}
}
visited[t] = false
return nil
default:
return nil
}
}
// fill returns true on success, meaning that type recursion was not detected.
func fill(v reflect.Value, visited map[reflect.Type]bool) bool {
if !v.IsValid() {
return true
}
switch v.Kind() {
case reflect.Pointer:
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}
return fill(v.Elem(), visited)
case reflect.Struct:
typ := v.Type()
if visited[typ] {
return false
}
visited[typ] = true
for i := range v.NumField() {
field := v.Field(i)
if !field.CanSet() {
continue
}
if !fill(field, visited) {
return false
}
}
visited[typ] = false
return true
default:
return true
}
}
package stack
import (
"slices"
)
// Stack is a first-in-last-out (FILO) data structure.
// Its zero value is an empty stack ready to use.
type Stack[T any] struct {
	values []T
}

// New returns an empty stack.
func New[T any]() Stack[T] {
	var s Stack[T]

	return s
}

// Len returns the number of values on the stack.
func (s *Stack[T]) Len() int {
	return len(s.values)
}

// Push puts v on top of the stack.
func (s *Stack[T]) Push(v T) {
	s.values = append(s.values, v)
}

// Pop removes the value on top of the stack and returns it.
// The boolean result is false, and the value is zero, if the stack is empty.
func (s *Stack[T]) Pop() (T, bool) {
	top := len(s.values) - 1
	if top < 0 {
		var zero T

		return zero, false
	}

	v := s.values[top]
	s.values = slices.Delete(s.values, top, top+1)

	return v, true
}
// Package testnamer copies test naming functionality as seen in go test.
//
// See [official go implementation].
//
// [official go implementation]: https://github.com/golang/go/blob/master/src/testing/match.go
package testnamer
import (
"fmt"
"strconv"
"strings"
"sync"
)
// Namer sanitizes and deduplicates names of subtests.
type Namer struct {
	mu sync.Mutex

	// subNames is used to deduplicate subtest names.
	// Each key is the subtest name joined to the deduplicated name of the parent test.
	// Each value is the count of the number of occurrences of the given subtest name
	// already seen.
	subNames map[string]int32
}

// New returns a Namer with no names seen yet.
func New() *Namer {
	namer := new(Namer)
	namer.subNames = make(map[string]int32)

	return namer
}
// Name returns the name for a subtest called subname under the given parent.
//
// With an empty parent, subname is returned unchanged.
// Otherwise subname is sanitized and made unique among the names
// already produced for that parent.
func (m *Namer) Name(parent, subname string) string {
	m.mu.Lock()
	defer m.mu.Unlock()

	if parent == "" {
		return subname
	}

	return m.unique(parent, rewrite(subname))
}
// unique creates a unique name for the given parent and subname by affixing it
// with one or more counts, if necessary.
//
// It reads and updates m.subNames without locking.
func (m *Namer) unique(parent, subname string) string {
base := parent + "/" + subname
for {
n := m.subNames[base]
if n < 0 {
panic("subtest count overflow")
}
m.subNames[base] = n + 1
// First occurrence of a non-empty subname: the base name itself may be used.
if n == 0 && subname != "" {
prefix, nn := parseSubtestNumber(base)
if len(prefix) < len(base) && nn < m.subNames[prefix] {
// This test is explicitly named like "parent/subname#NN",
// and #NN was already used for the NNth occurrence of "parent/subname".
// Loop to add a disambiguating suffix.
continue
}
return base
}
name := fmt.Sprintf("%s#%02d", base, n)
if m.subNames[name] != 0 {
// This is the nth occurrence of base, but the name "parent/subname#NN"
// collides with the first occurrence of a subtest *explicitly* named
// "parent/subname#NN". Try the next number.
continue
}
return name
}
}
// parseSubtestNumber splits a subtest name into a "#%02d"-formatted int32
// suffix (if present), and a prefix preceding that suffix (always).
//
// If s has no suffix of that form, it returns s itself and 0.
func parseSubtestNumber(s string) (prefix string, nn int32) {
	hash := strings.LastIndex(s, "#")
	if hash < 0 {
		return s, 0
	}

	head, digits := s[:hash], s[hash+1:]

	switch {
	case len(digits) < 2:
		// Too few digits to be an output of a "%02d" format string.
		return s, 0
	case len(digits) > 2 && digits[0] == '0':
		// Too many leading zeroes to be an output of a "%02d" format string.
		return s, 0
	case digits == "00" && !strings.HasSuffix(head, "/"):
		// We only use "#00" as a suffix for subtests named with the empty
		// string — it isn't a valid suffix if the subtest name is non-empty.
		return s, 0
	}

	parsed, err := strconv.ParseInt(digits, 10, 32)
	if err != nil || parsed < 0 {
		return s, 0
	}

	return head, int32(parsed)
}
// rewrite returns s with every white space rune replaced by an underscore
// and every non-printable rune replaced by its escaped (quoted) form.
func rewrite(s string) string {
	var out []byte

	for _, r := range s {
		if isSpace(r) {
			out = append(out, '_')

			continue
		}

		if strconv.IsPrint(r) {
			out = append(out, string(r)...)

			continue
		}

		// Use the quoted form of the rune without the surrounding quotes.
		quoted := strconv.QuoteRune(r)
		out = append(out, quoted[1:len(quoted)-1]...)
	}

	return string(out)
}
// isSpace reports whether r is one of the runes treated as white space
// when rewriting subtest names.
func isSpace(r rune) bool {
	// Note: not the same as Unicode Z class.
	switch r {
	case '\t', '\n', '\v', '\f', '\r', ' ', 0x85, 0xA0, 0x1680,
		0x2028, 0x2029, 0x202f, 0x205f, 0x3000:
		return true
	}

	return r >= 0x2000 && r <= 0x200a
}
package testo
import (
"slices"
"sync"
"github.com/ozontech/testo/testoplugin"
)
// Avoid using these values directly.
// Use [Options] and getOptions instead.
var (
// globalOptions are the options registered through [Options].
globalOptions []testoplugin.Option
// globalOptionsMutex guards globalOptions.
globalOptionsMutex sync.RWMutex
)
// Options appends the given options to the global options.
//
// Global options are prepended to each [RunSuite] call.
//
//	func init() {
//		testo.Options(myplugin.OutputDir("..."))
//	}
func Options(options ...testoplugin.Option) {
	globalOptionsMutex.Lock()
	defer globalOptionsMutex.Unlock()

	extended := append(globalOptions, options...)
	globalOptions = extended
}
// getOptions returns a copy of the global options, so the result
// can be appended to without affecting the shared slice.
func getOptions() []testoplugin.Option {
	globalOptionsMutex.RLock()
	defer globalOptionsMutex.RUnlock()

	current := globalOptions

	return slices.Clone(current)
}
package testo
import (
"cmp"
"slices"
"testing"
"github.com/ozontech/testo/testoplugin"
"github.com/ozontech/testo/testoreflect"
)
// mergeSpecs merges multiple plugin specs into one:
// plans, hooks and overrides of all specs are merged separately.
func mergeSpecs(tb testing.TB, plugins ...testoplugin.Spec) testoplugin.Spec {
	tb.Helper()

	var (
		plans     = make([]testoplugin.Plan, len(plugins))
		hooks     = make([]testoplugin.Hooks, len(plugins))
		overrides = make([]testoplugin.Overrides, len(plugins))
	)

	for i := range plugins {
		plans[i] = plugins[i].Plan
		hooks[i] = plugins[i].Hooks
		overrides[i] = plugins[i].Overrides
	}

	merged := testoplugin.Spec{
		Plan:      mergePlans(tb, plans...),
		Hooks:     mergeHooks(tb, hooks...),
		Overrides: mergeOverrides(overrides...),
	}

	return merged
}
// mergePlans combines plans into a single plan whose Prepare
// calls every non-nil Prepare of the given plans, in order.
func mergePlans(tb testing.TB, plans ...testoplugin.Plan) testoplugin.Plan {
	tb.Helper()

	prepareAll := func(suite testoreflect.SuiteInfo, tests *[]testoplugin.PlannedTest) {
		tb.Helper()

		// We could've broken the loop when len(tests) == 0,
		// but it may be useful if some plugin would want to throw some warning or error
		// when len(tests) == 0. Something like NoEmptySuitesPlugin.
		for i := range plans {
			prepare := plans[i].Prepare
			if prepare == nil {
				continue
			}

			prepare(suite, tests)
		}
	}

	return testoplugin.Plan{Prepare: prepareAll}
}
// mergeHooks combines hooks of several plugins into one set of hooks.
//
// For each kind of hook, all hooks with a non-nil Func are collected,
// stably sorted by Priority in ascending order and wrapped into a single hook
// which runs them one after another.
func mergeHooks(tb testing.TB, hooks ...testoplugin.Hooks) testoplugin.Hooks {
tb.Helper()
beforeAll := make([]testoplugin.Hook, 0, len(hooks))
beforeEach := make([]testoplugin.Hook, 0, len(hooks))
beforeEachSub := make([]testoplugin.Hook, 0, len(hooks))
afterEachSub := make([]testoplugin.Hook, 0, len(hooks))
afterEach := make([]testoplugin.Hook, 0, len(hooks))
afterAll := make([]testoplugin.Hook, 0, len(hooks))
// Hooks without a Func are skipped.
for _, h := range hooks {
if h := h.BeforeAll; h.Func != nil {
beforeAll = append(beforeAll, h)
}
if h := h.BeforeEach; h.Func != nil {
beforeEach = append(beforeEach, h)
}
if h := h.BeforeEachSub; h.Func != nil {
beforeEachSub = append(beforeEachSub, h)
}
if h := h.AfterEachSub; h.Func != nil {
afterEachSub = append(afterEachSub, h)
}
if h := h.AfterEach; h.Func != nil {
afterEach = append(afterEach, h)
}
if h := h.AfterAll; h.Func != nil {
afterAll = append(afterAll, h)
}
}
// run sorts the hooks once, when the merged hook is built,
// and returns a func which executes them in that order.
run := func(hooks []testoplugin.Hook) func() {
tb.Helper()
slices.SortStableFunc(hooks, func(a, b testoplugin.Hook) int {
return cmp.Compare(a.Priority, b.Priority)
})
return func() {
tb.Helper()
for _, h := range hooks {
runHook(tb, h)
}
}
}
return testoplugin.Hooks{
BeforeAll: testoplugin.Hook{Func: run(beforeAll)},
BeforeEach: testoplugin.Hook{Func: run(beforeEach)},
BeforeEachSub: testoplugin.Hook{Func: run(beforeEachSub)},
AfterEachSub: testoplugin.Hook{Func: run(afterEachSub)},
AfterEach: testoplugin.Hook{Func: run(afterEach)},
AfterAll: testoplugin.Hook{Func: run(afterAll)},
}
}
// mergeOverrides combines overrides of several plugins into one.
//
// Each field of the result is built by mergeOverride from the
// same-named field of every given Overrides value.
//
//nolint:funlen // splitting this into subfunctons would make it worse
func mergeOverrides(overrides ...testoplugin.Overrides) testoplugin.Overrides {
return testoplugin.Overrides{
Log: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncLog] {
return o.Log
},
),
Parallel: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncParallel] {
return o.Parallel
},
),
Setenv: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncSetenv] {
return o.Setenv
},
),
TempDir: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncTempDir] {
return o.TempDir
},
),
Deadline: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncDeadline] {
return o.Deadline
},
),
Context: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncContext] {
return o.Context
},
),
Chdir: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncChdir] {
return o.Chdir
},
),
Cleanup: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncCleanup] {
return o.Cleanup
},
),
Error: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncError] {
return o.Error
},
),
Skip: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncSkip] {
return o.Skip
},
),
SkipNow: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncSkipNow] {
return o.SkipNow
},
),
Skipped: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncSkipped] {
return o.Skipped
},
),
Fail: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncFail] {
return o.Fail
},
),
FailNow: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncFailNow] {
return o.FailNow
},
),
Failed: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncFailed] {
return o.Failed
},
),
Fatal: mergeOverride(
overrides,
func(o testoplugin.Overrides) testoplugin.Override[testoplugin.FuncFatal] {
return o.Fatal
},
),
}
}
// mergeOverride returns a func which applies, in order, every non-nil
// override selected by getter from overrides to the given f.
//
// Each override receives the result of the previous one,
// so later overrides wrap earlier ones.
func mergeOverride[F any](
	overrides []testoplugin.Overrides,
	getter func(testoplugin.Overrides) testoplugin.Override[F],
) func(F) F {
	return func(f F) F {
		result := f

		for i := range overrides {
			override := getter(overrides[i])
			if override == nil {
				continue
			}

			result = override(result)
		}

		return result
	}
}
package testo
import (
"fmt"
"reflect"
"runtime/debug"
"testing"
"github.com/ozontech/testo/internal/reflectutil"
"github.com/ozontech/testo/internal/testnamer"
"github.com/ozontech/testo/testoplugin"
"github.com/ozontech/testo/testoreflect"
)
// parallelWrapperTest is the name of the test which
// wraps multiple (possibly parallel) tests to ensure
// hooks are executed properly.
//
// It should contain some special symbol which Go identifiers
// cannot include (like an exclamation mark), so that it won't collide with a suite type name.
const parallelWrapperTest = "testo!"
// RunSuite runs the tests of the given suite.
//
// A test is a suite method named "TestXXX" or "Test"
// which accepts a single parameter of the same type as T passed to this function.
//
// Options given here configure the plugins.
// See [testoplugin.Option].
//
// RunSuite reports whether all suite tests succeeded.
func RunSuite[Suite suite[T], T CommonT](
	testingT TestingT,
	suite Suite,
	options ...testoplugin.Option,
) bool {
	testingT.Helper()

	suiteRunner := newRunner[Suite]()
	succeeded := suiteRunner.runSuite(testingT, suite, options...)

	return succeeded
}
// Run runs f as a subtest of t called name. It runs f in a separate goroutine
// and blocks until f returns or calls t.Parallel to become a parallel test.
// Run reports whether f succeeded (or at least did not fail before calling t.Parallel).
//
// Run may be called simultaneously from multiple goroutines, but all such calls
// must return before the outer test function for t returns.
//
// A nil f is treated as an empty test function.
// The given options are passed to the T constructed for the subtest.
//
// WARN: Running this function during t.Cleanup panics.
func Run[T CommonT](
t T,
name string,
f func(t T),
options ...testoplugin.Option,
) bool {
t.Helper()
if f == nil {
f = func(T) {}
}
parentT := t
return parentT.unwrap().testingT.Run(name, func(testingT *testing.T) {
testingT.Helper()
// The subtest T is derived from the parent and shares its namer and suite info.
t := construct(
testingT,
&parentT,
func(t *testoT) {
t.testNamer = parentT.unwrap().testNamer
t.reflection.Suite = parentT.unwrap().reflection.Suite
t.reflection.Test = testoreflect.RegularTestInfo{
Name: parentT.unwrap().testNamer.Name(parentT.unwrap().Name(), name),
RawBaseName: name,
Level: t.level(),
IsSubtest: true,
}
},
options...,
)
// Registered first, hence runs last: a panic raised by f or by the hooks
// is recorded in the reflection info and reported as a fatal error.
defer func() {
if r := recover(); r != nil {
trace := string(debug.Stack())
t.unwrap().reflection.Panic = &testoreflect.PanicInfo{
Value: r,
Trace: trace,
}
t.Fatalf("testo: test %q panicked: %v\n\n%s", t.Name(), r, trace)
}
}()
defer runHook(t, t.unwrap().spec.Hooks.AfterEachSub)
runHook(t, t.unwrap().spec.Hooks.BeforeEachSub)
f(t)
})
}
// runner runs the tests of a single suite type.
type runner[Suite suite[T], T CommonT] struct {
	// suiteName is the name of the suite type.
	suiteName string
	// testNamer produces the names of the tests run by this runner.
	testNamer *testnamer.Namer
}

// newRunner returns a runner for Suite with a fresh test namer.
func newRunner[Suite suite[T], T CommonT]() runner[Suite, T] {
	var r runner[Suite, T]

	r.suiteName = reflectutil.NameOf[Suite]()
	r.testNamer = testnamer.New()

	return r
}
// collectTests discovers the tests of the suite type.
// caller is used as the parent name when naming the tests.
func (r *runner[Suite, T]) collectTests(t TestingT, caller string) suiteTests[Suite, T] {
	t.Helper()

	collector := &testsCollector[Suite, T]{
		TestNamer:  r.testNamer,
		CallerName: caller,
	}

	return collector.Collect(t)
}
// runSuite runs the suite as a subtest of testingT named after the suite type.
//
// Options returned by getOptions come first, followed by the explicitly passed options.
// It reports the result of the underlying testingT.Run call.
func (r *runner[Suite, T]) runSuite(
	testingT TestingT,
	suite Suite,
	options ...testoplugin.Option,
) bool {
	testingT.Helper()

	options = append(getOptions(), options...)

	caller := r.testNamer.Name(testingT.Name(), r.suiteName)
	tests := r.collectTests(testingT, caller)

	var suiteInfo testoreflect.SuiteInfo

	suiteInfo.Name = r.suiteName
	suiteInfo.Caller = testingT.Name()
	suiteInfo.TestingT = testingT
	suiteInfo.Value = suite

	// configure fills in the top-level suite T.
	configure := func(root *testoT) {
		root.testNamer = r.testNamer
		root.reflection.Suite = suiteInfo
		root.reflection.Test = testoreflect.RegularTestInfo{
			Name:        caller,
			RawBaseName: r.suiteName,
		}
	}

	return testingT.Run(r.suiteName, func(suiteT *testing.T) {
		suiteT.Helper()

		t := construct[T](suiteT, nil, configure, options...)

		t.unwrap().logPlugins()
		r.runSuiteTests(t, suite, tests)
	})
}
// runHook invokes the hook function if it is set; a hook without Func is a no-op.
func runHook(t testing.TB, h testoplugin.Hook) {
	t.Helper()

	if h.Func == nil {
		return
	}

	h.Func()
}
// runSuiteTests runs BeforeAll/AfterAll hooks around all tests of the suite.
//
// Plugin hooks wrap suite hooks: plugin BeforeAll runs before the suite BeforeAll,
// and plugin AfterAll runs after the suite AfterAll (deferred calls run in reverse order).
// Both AfterAll hooks are not run if t is reported as skipped.
func (r *runner[Suite, T]) runSuiteTests(t T, s Suite, tests suiteTests[Suite, T]) {
t.Helper()
// Registered first, therefore executed last.
defer func() {
if !t.Skipped() {
runHook(t, t.unwrap().spec.Hooks.AfterAll)
}
}()
runHook(t, t.unwrap().spec.Hooks.BeforeAll)
defer func() {
if !t.Skipped() {
s.AfterAll(t)
}
}()
s.BeforeAll(t)
// Rebuild suite info after BeforeAll so that it carries
// the suite value s and the hooks info recorded so far.
suiteInfo := testoreflect.SuiteInfo{
Name: t.unwrap().reflection.Suite.Name,
Caller: t.unwrap().reflection.Suite.Caller,
TestingT: t.unwrap().reflection.Suite.TestingT,
Value: s,
Hooks: t.unwrap().reflection.Suite.Hooks,
}
// Plugins may change the list of tests to run through the plan.
allTests := r.applyPlan(
t,
suiteInfo,
tests.Collect(s),
)
// All tests run inside a single wrapper subtest; Run on the wrapper
// returns only when all of its subtests (including parallel ones) are complete.
t.unwrap().testingT.Run(parallelWrapperTest, func(testingT *testing.T) {
testingT.Helper()
for _, test := range allTests {
testingT.Run(test.Name, func(testingT *testing.T) {
innerT := construct(
testingT,
&t,
func(t *testoT) {
t.testNamer = r.testNamer
t.reflection.Suite = suiteInfo
t.reflection.Test = test.Info
},
test.Options...,
)
r.runSuiteTest(
innerT,
s,
test.suiteTest,
)
})
}
})
}
// runSuiteTest runs a single suite test wrapped with hooks.
//
// Execution order: plugin BeforeEach, suite BeforeEach, the test itself,
// suite AfterEach, plugin AfterEach. A panic raised along the way is recovered last,
// recorded in the reflection and reported through Fatalf.
func (r *runner[Suite, T]) runSuiteTest(
	t T,
	s Suite,
	test suiteTest[Suite, T],
) {
	t.Helper()

	// Registered first so that it runs after all other deferred calls.
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}

		trace := string(debug.Stack())

		t.unwrap().reflection.Panic = &testoreflect.PanicInfo{
			Value: recovered,
			Trace: trace,
		}

		t.Fatalf("testo: test %q panicked: %v\n\n%s", t.Name(), recovered, trace)
	}()

	defer runHook(t, t.unwrap().spec.Hooks.AfterEach)
	runHook(t, t.unwrap().spec.Hooks.BeforeEach)

	defer s.AfterEach(t)
	s.BeforeEach(t)

	test.Run(s, t)
}
// applyPlan passes collected tests through the plan of plugins.
//
// If a Prepare function is set in the plan, it receives the suite info
// and a pointer to the planned tests which it may modify.
// Nil entries left in the resulting list are dropped.
//
// It panics if a planned test has an unexpected dynamic type.
func (r *runner[Suite, T]) applyPlan(
	t T,
	suiteInfo testoreflect.SuiteInfo,
	tests []annotatedSuiteTest[Suite, T],
) []annotatedSuiteTest[Suite, T] {
	t.Helper()

	planned := make([]testoplugin.PlannedTest, 0, len(tests))
	for _, test := range tests {
		planned = append(planned, plannedSuiteTest[Suite, T]{test})
	}

	prepare := t.unwrap().spec.Plan.Prepare
	if prepare != nil {
		prepare(suiteInfo, &planned)
	}

	result := make([]annotatedSuiteTest[Suite, T], 0, len(planned))

	for _, candidate := range planned {
		switch test := candidate.(type) {
		case nil:
			// removed from the plan
		case plannedSuiteTest[Suite, T]:
			result = append(result, test.inner)
		default:
			// must be unreachable because of "DoNotImplement" directive.
			panic(fmt.Sprintf(
				"testo: planned test is not of type %q",
				reflect.TypeFor[plannedSuiteTest[Suite, T]](),
			))
		}
	}

	return result
}
package testo
// suite is the set of hooks every suite provides.
// It is sealed by an unexported method.
type suite[T CommonT] interface {
// BeforeAll is called before all suite tests once.
// T is shared with a top-level suite test.
// Failing here will fail the entire suite.
BeforeAll(t T)
// BeforeEach is called before each suite test.
// T is shared with an actual test.
BeforeEach(t T)
// AfterEach is called after each suite test.
// T is shared with an actual test.
//
// WARN: this hook is deferred to run at the end of the test.
// If that test has sub-tests marked as parallel,
// this hook will run BEFORE those sub-tests are finished.
//
// Unless you need to run sub-tests during this hook,
// it is recommended to use t.Cleanup during BeforeEach.
AfterEach(t T)
// AfterAll is called after all suite tests once.
// T is shared with a top-level suite test.
// Failing here will fail the entire suite.
AfterAll(t T)
// A private method to prevent users from implementing the
// interface, so that future additions to it will not
// violate backwards compatibility.
private()
}
// Suite is the base suite that all user-defined
// suites must embed.
//
// Suite may optionally provide hooks by implementing their methods:
//
// - BeforeAll(T) - is called before all suite tests once.
// - BeforeEach(T) - is called before each suite test. T is shared with an actual test.
// - AfterEach(T) - is called after each suite test. T is shared with an actual test.
// - AfterAll(T) - is called after all suite tests once.
//
// Hooks that are not redefined fall back to the default ones of Suite,
// which only record in the suite hooks info that the hook was missed.
//
// Example:
//
// type MySuite struct {
// testo.Suite[MyT]
// }
type Suite[T CommonT] struct {
// Mention *T in a field to disallow conversion between Pointer types.
// See go.dev/issue/56603 for more details.
// Use *T, not T, to avoid spurious recursive type definition errors.
_ [0]*T
}
// BeforeAll is the default hook used when a suite does not define its own.
// It only records that the BeforeAll hook was missed.
func (Suite[T]) BeforeAll(t T) {
	hooks := &t.unwrap().reflection.Suite.Hooks
	hooks.MissedBeforeAll = true
}
// BeforeEach is the default hook used when a suite does not define its own.
// It only records that the BeforeEach hook was missed.
func (Suite[T]) BeforeEach(t T) {
	hooks := &t.unwrap().reflection.Suite.Hooks
	hooks.MissedBeforeEach = true
}
// AfterEach is the default hook used when a suite does not define its own.
// It only records that the AfterEach hook was missed.
func (Suite[T]) AfterEach(t T) {
	hooks := &t.unwrap().reflection.Suite.Hooks
	hooks.MissedAfterEach = true
}
// AfterAll is the default hook used when a suite does not define its own.
// It only records that the AfterAll hook was missed.
func (Suite[T]) AfterAll(t T) {
	hooks := &t.unwrap().reflection.Suite.Hooks
	hooks.MissedAfterAll = true
}
// private seals the suite interface: only types embedding Suite can satisfy it.
//
//nolint:unused // sealed interface
func (Suite[T]) private() {}
//nolint:funcorder // private and public methods close to each other for readability
package testo
import (
"context"
"fmt"
"os"
"path/filepath"
"reflect"
"runtime"
"slices"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/ozontech/testo/internal/reflectutil"
"github.com/ozontech/testo/internal/testnamer"
"github.com/ozontech/testo/testoplugin"
"github.com/ozontech/testo/testoreflect"
)
// common is [testing.TB] extended with the Deadline and Parallel methods.
type common interface {
testing.TB
Deadline() (deadline time.Time, ok bool)
Parallel()
}
// TestingT is an interface for [testing.T].
//
// In addition to the common methods it allows running subtests.
type TestingT interface {
common
Run(name string, f func(t *testing.T)) bool
}
// CommonT is the interface common for all [T] derivatives.
//
// The unexported unwrap method gives access to the underlying [T].
type CommonT interface {
common
unwrap() *T
}
type (
// T is a wrapper for [testing.T].
// This is a core entity in testo and used as a [testing.T] replacement.
//
// The common pattern is to embed it into new struct type:
//
// type MyT struct {
// *testo.T
// *SomePlugin
// }
//
// Plugins can also optionally embed it - testo will automatically initialize it
// by sharing the same value as an actual currently running test's T.
//
// type SomePlugin struct { *testo.T }
T struct {
common
// we double testing t interfaces,
// so that we can still have access to testing.T.Run
// but the user doesn't.
testingT TestingT
testNamer *testnamer.Namer
// parent is the T of the enclosing test; nil if there is none.
parent *T
// spec holds plan, hooks and overrides used by this T.
spec testoplugin.Spec
// levelOptions stores options passed for the
// current level through [Run], [RunSuite] or [For].
levelOptions []testoplugin.Option
// reflection holds information for [Reflect].
reflection testoreflect.Reflection
// failure state is stored atomically and exposed through [Reflect].
failureSource atomicInt[testoreflect.TestFailureSource]
failureKind atomicInt[testoreflect.TestFailureKind]
hasFatalSubtest atomic.Bool
// plugins maps plugin type to its instance.
plugins map[reflect.Type]testoplugin.Plugin
}
// testoT is an alias for T, usable where the name T is taken by a type parameter.
testoT = T
)
// Plugin implements [testoplugin.Plugin] and always returns an empty spec.
//
// This is a placeholder to prevent other .Plugin methods being promoted.
func (*T) Plugin(testoplugin.Plugin, ...testoplugin.Option) testoplugin.Spec {
	var empty testoplugin.Spec

	return empty
}
// Context returns a context that is canceled just before
// Cleanup-registered functions are called.
//
// Cleanup functions can wait for any resources
// that shut down on [context.Context.Done] before the test completes.
//
// The call goes through the Context override if one is set.
func (t *T) Context() context.Context {
	t.Helper()

	call := t.spec.Overrides.Context.Call(t.common.Context)

	return call()
}
// Parallel signals that this test is to be run in parallel with (and only with)
// other parallel tests. When a test is run multiple times due to use of
// -test.count or -test.cpu, multiple instances of a single test never run in
// parallel with each other.
//
// The call goes through the Parallel override if one is set.
func (t *T) Parallel() {
	t.Helper()

	call := t.spec.Overrides.Parallel.Call(t.common.Parallel)
	call()
}
// Setenv calls os.Setenv(key, value) and uses Cleanup to
// restore the environment variable to its original value
// after the test.
//
// Because Setenv affects the whole process, it cannot be used
// in parallel tests or tests with parallel ancestors.
//
// The call goes through the Setenv override if one is set.
func (t *T) Setenv(key, value string) {
	t.Helper()

	call := t.spec.Overrides.Setenv.Call(t.setenv)
	call(key, value)
}
// setenv sets the environment variable and registers a cleanup
// restoring its previous state: the old value if the variable
// was set, otherwise the variable is unset.
//
// It mirrors testing.common.Setenv. The native method is not used
// because it would bypass overrides for methods such as fatal or cleanup.
func (t *T) setenv(key, value string) {
	t.Helper()

	previous, existed := os.LookupEnv(key)

	if err := os.Setenv(key, value); err != nil {
		t.Fatalf("cannot set environment variable: %v", err)
	}

	restore := func() { _ = os.Unsetenv(key) }
	if existed {
		restore = func() { _ = os.Setenv(key, previous) }
	}

	t.Cleanup(restore)
}
// TempDir returns a temporary directory for the test to use.
// The directory is automatically removed when the test and
// all its subtests complete.
// Each subsequent call to t.TempDir returns a unique directory;
// if the directory creation fails, TempDir terminates the test by calling Fatal.
//
// The call goes through the TempDir override if one is set.
func (t *T) TempDir() string {
	t.Helper()

	call := t.spec.Overrides.TempDir.Call(t.common.TempDir)

	return call()
}
// Log formats its arguments using default formatting, analogous to Println,
// and records the text in the error log. For tests, the text will be printed only if
// the test fails or the -test.v flag is set. For benchmarks, the text is always
// printed to avoid having performance depend on the value of the -test.v flag.
//
// The call goes through the Log override if one is set.
func (t *T) Log(args ...any) {
	t.Helper()

	call := t.spec.Overrides.Log.Call(t.common.Log)
	call(args...)
}
// Logf formats its arguments according to the format, analogous to Printf, and
// records the text in the error log. A final newline is added if not provided. For
// tests, the text will be printed only if the test fails or the -test.v flag is
// set. For benchmarks, the text is always printed to avoid having performance
// depend on the value of the -test.v flag.
//
// The formatted message is passed to Log.
func (t *T) Logf(format string, args ...any) {
	t.Helper()

	message := fmt.Sprintf(format, args...)
	t.Log(message)
}
// Deadline reports the time at which the test binary will have
// exceeded the timeout specified by the -timeout flag.
//
// By default, the ok result is false if the -timeout flag indicates "no timeout" (0).
//
// The call goes through the Deadline override if one is set.
func (t *T) Deadline() (time.Time, bool) {
	t.Helper()

	call := t.spec.Overrides.Deadline.Call(t.common.Deadline)

	return call()
}
// Errorf is equivalent to Error with formatted message.
func (t *T) Errorf(format string, args ...any) {
	t.Helper()

	message := fmt.Sprintf(format, args...)
	t.Error(message)
}
// Error is equivalent to Log followed by Fail.
//
// The call goes through the Error override if one is set.
func (t *T) Error(args ...any) {
	t.Helper()

	call := t.spec.Overrides.Error.Call(t.error)
	call(args...)
}
// error is the default implementation of Error: Log followed by Fail.
// Both calls go through the public methods.
func (t *T) error(args ...any) {
t.Helper()
t.Log(args...)
t.Fail()
}
// Skip is equivalent to Log followed by SkipNow.
//
// The call goes through the Skip override if one is set.
func (t *T) Skip(args ...any) {
	t.Helper()

	call := t.spec.Overrides.Skip.Call(t.skip)
	call(args...)
}
// skip is the default implementation of Skip: Log followed by SkipNow.
// Both calls go through the public methods.
func (t *T) skip(args ...any) {
t.Helper()
t.Log(args...)
t.SkipNow()
}
// SkipNow marks the test as having been skipped and stops its execution
// by calling [runtime.Goexit].
// If a test fails (see Error, Errorf, Fail) and is then skipped,
// it is still considered to have failed.
// Execution will continue at the next test or benchmark. See also FailNow.
// SkipNow must be called from the goroutine running the test, not from
// other goroutines created during the test. Calling SkipNow does not stop
// those other goroutines.
//
// The call goes through the SkipNow override if one is set.
func (t *T) SkipNow() {
	t.Helper()

	call := t.spec.Overrides.SkipNow.Call(t.common.SkipNow)
	call()
}
// Skipf is equivalent to Skip with formatted message.
func (t *T) Skipf(format string, args ...any) {
	t.Helper()

	message := fmt.Sprintf(format, args...)
	t.Skip(message)
}
// Skipped reports whether the test was skipped.
//
// The call goes through the Skipped override if one is set.
func (t *T) Skipped() bool {
	t.Helper()

	call := t.spec.Overrides.Skipped.Call(t.common.Skipped)

	return call()
}
// Fail marks the function as having failed but continues execution.
//
// The call goes through the Fail override if one is set.
func (t *T) Fail() {
	t.Helper()

	call := t.spec.Overrides.Fail.Call(t.fail)
	call()
}
// fail is the default implementation of Fail.
// It records a soft failure before failing the underlying test.
func (t *T) fail() {
t.Helper()
t.markFailure(testoreflect.TestFailureKindSoft)
t.common.Fail()
}
// FailNow marks the function as having failed and stops its execution
// by calling runtime.Goexit (which then runs all deferred calls in the
// current goroutine).
// Execution will continue at the next test or benchmark.
// FailNow must be called from the goroutine running the
// test or benchmark function, not from other goroutines
// created during the test. Calling FailNow does not stop
// those other goroutines.
//
// The call goes through the FailNow override if one is set.
func (t *T) FailNow() {
	t.Helper()

	call := t.spec.Overrides.FailNow.Call(t.failNow)
	call()
}
// failNow is the default implementation of FailNow.
// It records a fatal failure before failing the underlying test.
func (t *T) failNow() {
t.Helper()
t.markFailure(testoreflect.TestFailureKindFatal)
t.common.FailNow()
}
// Failed reports whether the function has failed.
//
// The call goes through the Failed override if one is set.
func (t *T) Failed() bool {
	t.Helper()

	call := t.spec.Overrides.Failed.Call(t.common.Failed)

	return call()
}
// Fatal is equivalent to Log followed by FailNow.
//
// The call goes through the Fatal override if one is set.
func (t *T) Fatal(args ...any) {
	t.Helper()

	call := t.spec.Overrides.Fatal.Call(t.fatal)
	call(args...)
}
// fatal is the default implementation of Fatal: Log followed by FailNow.
// Both calls go through the public methods.
func (t *T) fatal(args ...any) {
t.Helper()
t.Log(args...)
t.FailNow()
}
// Fatalf is equivalent to Fatal with formatted message.
func (t *T) Fatalf(format string, args ...any) {
	t.Helper()

	message := fmt.Sprintf(format, args...)
	t.Fatal(message)
}
// Chdir calls [os.Chdir] and uses Cleanup to restore the current
// working directory to its original value after the test. On Unix, it
// also sets PWD environment variable for the duration of the test.
//
// Because Chdir affects the whole process, it cannot be used
// in parallel tests or tests with parallel ancestors.
//
// The call goes through the Chdir override if one is set.
func (t *T) Chdir(dir string) {
	t.Helper()

	call := t.spec.Overrides.Chdir.Call(t.chdir)
	call(dir)
}
// chdir is 1:1 copy from testing.common.Chdir.
// we don't use native chdir because that way we won't use
// overrides for methods such as fatal or cleanup.
//
// It changes the working directory to dir, updates PWD where applicable
// and registers a cleanup that restores the original working directory.
func (t *T) chdir(dir string) {
t.Helper()
// Keep a handle to the current directory to be able to return to it later.
oldwd, err := os.Open(".")
if err != nil {
t.Fatal(err)
}
if err := os.Chdir(dir); err != nil {
t.Fatal(err)
}
// On POSIX platforms, PWD represents "an absolute pathname of the
// current working directory." Since we are changing the working
// directory, we should also set or update PWD to reflect that.
switch runtime.GOOS {
case "windows", "plan9":
// Windows and Plan 9 do not use the PWD variable.
default:
if !filepath.IsAbs(dir) {
// Resolve a relative dir to the absolute path of the new working directory.
dir, err = os.Getwd()
if err != nil {
t.Fatal(err)
}
}
t.Setenv("PWD", dir)
}
t.Cleanup(func() {
err := oldwd.Chdir()
_ = oldwd.Close()
if err != nil {
// It's not safe to continue with tests if we can't
// get back to the original working directory. Since
// we are holding a dirfd, this is highly unlikely.
panic("testo.Chdir: " + err.Error())
}
})
}
// Cleanup registers a function to be called when the test (or subtest) and all its
// subtests complete. Cleanup functions will be called in last added,
// first called order.
//
// The call goes through the Cleanup override if one is set.
func (t *T) Cleanup(f func()) {
	t.Helper()

	call := t.spec.Overrides.Cleanup.Call(t.common.Cleanup)
	call(f)
}
// Name returns the name of the running (sub-) test or benchmark.
//
// The name will include the name of the test along with the names of
// any nested sub-tests. If two sibling sub-tests have the same name,
// Name will append a suffix to guarantee the returned name is unique.
//
// The value is taken from the test info stored in the reflection.
func (t *T) Name() string {
	t.Helper()

	info := t.reflection.Test

	return info.GetName()
}
// unwrap returns the underlying T, which is the receiver itself.
//
// It works since T's are embedded in user-defined structs,
// so the method is promoted to them.
func (t *T) unwrap() *T {
return t
}
// level reports the depth of this t,
// i.e. the number of ancestors it has (zero if none).
func (t *T) level() int {
	depth := 0

	for ancestor := t.parent; ancestor != nil; ancestor = ancestor.parent {
		depth++
	}

	return depth
}
// markFailure records a failure of the given kind for the current t
// and promotes it to all ancestors.
//
// The failure kind of t is never downgraded: once a fatal failure is recorded,
// a later soft failure (e.g. t.Error called from a deferred AfterEach hook
// or from a cleanup function after t.FailNow) leaves it fatal.
// The failure source of t is always set to "self".
//
// Each ancestor without a failure gets a soft failure with "child" source.
// If kind is fatal, ancestors are additionally flagged as having a fatal subtest.
//
// Passing [testoreflect.TestFailureKindNone] is a no-op.
func (t *T) markFailure(kind testoreflect.TestFailureKind) {
	if kind == testoreflect.TestFailureKindNone {
		return
	}

	if kind == testoreflect.TestFailureKindFatal {
		t.failureKind.Store(kind)
	} else {
		// CAS loop, since failures may be reported from multiple goroutines.
		for {
			current := t.failureKind.Load()
			if current == testoreflect.TestFailureKindFatal ||
				t.failureKind.CompareAndSwap(current, kind) {
				break
			}
		}
	}

	t.failureSource.Store(testoreflect.TestFailureSourceSelf)

	parent := t.parent
	for parent != nil {
		// parent may already have fatal failure,
		// so we overwrite parent failure kind only if it has none.
		parent.failureKind.CompareAndSwap(
			testoreflect.TestFailureKindNone,
			testoreflect.TestFailureKindSoft,
		)
		parent.failureSource.CompareAndSwap(
			testoreflect.TestFailureSourceNone,
			testoreflect.TestFailureSourceChild,
		)

		if kind == testoreflect.TestFailureKindFatal {
			parent.hasFatalSubtest.Store(true)
		}

		parent = parent.parent
	}
}
// options returns all options effective for this t.
//
// The result consists of propagated options of all ancestors,
// starting from the outermost one, followed by all options of t itself.
func (t *T) options() []testoplugin.Option {
	total := len(t.levelOptions)

	// levels[0] is the own level, followed by ancestors from nearest to farthest.
	levels := [][]testoplugin.Option{t.levelOptions}

	for ancestor := t.parent; ancestor != nil; ancestor = ancestor.parent {
		propagated := make([]testoplugin.Option, 0, len(ancestor.levelOptions))

		for _, option := range ancestor.levelOptions {
			if !option.Propagate {
				continue
			}

			propagated = append(propagated, option)
		}

		levels = append(levels, propagated)
		total += len(propagated)
	}

	merged := make([]testoplugin.Option, 0, total)

	// so that child options come after parent options
	for i := len(levels) - 1; i >= 0; i-- {
		merged = append(merged, levels[i]...)
	}

	return merged
}
// pluginNames returns sorted type names of the registered plugins.
// The entry for *T itself is not included.
func (t *T) pluginNames() []string {
	plugins := t.unwrap().plugins
	selfType := reflect.TypeFor[*T]()

	names := make([]string, 0, len(plugins))

	for pluginType := range plugins {
		if pluginType != selfType {
			names = append(names, reflectutil.Elem(pluginType).String())
		}
	}

	slices.Sort(names)

	return names
}
// logPlugins logs the number and the names of the collected plugins
// directly to the underlying testing t. Nothing is logged if there are none.
func (t *T) logPlugins() {
	t.Helper()

	names := t.unwrap().pluginNames()
	count := len(names)

	if count == 0 {
		return
	}

	joined := strings.Join(names, ", ")

	t.testingT.Logf("testo: plugins collected: %d: %s\n", count, joined)
}
// Reflect returns meta information about given t.
//
// The returned value is a copy of the reflection stored in t
// with failure state and the underlying testing t filled in at the time of the call.
//
// You can reflect over any test by accessing its T instance:
//
//	func (Suite) TestFoo(t T) {
//		r := testo.Reflect(t)
//		// r stores Reflection struct.
//	}
//
// Same logic applies for plugins.
// If a plugin embeds `*testo.T` it can call the same [testo.Reflect] function:
//
//	type Plugin struct{ *testo.T }
//
//	func (p *Plugin) Plugin(parent testoplugin.Plugin, options ...testoplugin.Option) testoplugin.Spec {
//		return testoplugin.Spec{
//			Hooks: testoplugin.Hooks{
//				BeforeEach: testoplugin.Hook{
//					Func: func() { testo.Reflect(p) }
//				}
//			}
//		}
//	}
func Reflect(t CommonT) testoreflect.Reflection {
	t.Helper()

	underlying := t.unwrap()
	snapshot := underlying.reflection

	snapshot.FailureSource = underlying.failureSource.Load()
	snapshot.FailureKind = underlying.failureKind.Load()
	snapshot.HasFatalSubtest = underlying.hasFatalSubtest.Load()
	snapshot.TestingT = underlying.testingT

	return snapshot
}
// atomicInt is a typed wrapper over [atomic.Int64]
// for integer-based types such as enums.
type atomicInt[T ~int | ~int8 | ~int32 | ~int64] atomic.Int64

// raw returns the underlying atomic value.
func (a *atomicInt[T]) raw() *atomic.Int64 {
	return (*atomic.Int64)(a)
}

// Load atomically loads and returns the stored value.
func (a *atomicInt[T]) Load() T {
	stored := a.raw().Load()

	return T(stored)
}

// Store atomically stores the value.
func (a *atomicInt[T]) Store(value T) {
	a.raw().Store(int64(value))
}

// CompareAndSwap executes the compare-and-swap operation
// and reports whether the swap happened.
func (a *atomicInt[T]) CompareAndSwap(oldvalue, newvalue T) bool {
	from, to := int64(oldvalue), int64(newvalue)

	return a.raw().CompareAndSwap(from, to)
}
// Package testocache provides caching primitives to be used by
// external plugins.
//
// By default, it stores cache in a directory "$TWD/.testo_cache",
// where "$TWD" refers to the "test working directory" (not necessarily a project root).
// Usually, this is a directory where "_test.go" file, which calls this package, is located.
//
// Can be overridden passing "-cache.dir ~/My/Dir" flag to the "go test"
// command OR (with lesser priority) with environment variable "TESTO_CACHE_DIR".
//
// Caching can also be disabled with flag "-cache.disable" or environment
// variable "TESTO_CACHE_DISABLE" (e.g. "=true").
package testocache
import (
"cmp"
"errors"
"flag"
"io/fs"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
"unicode"
)
// Command-line flags of the cache.
// Environment variables provide default values, so flags take priority over them.
var (
// flagDir is the cache directory; defaults to $TESTO_CACHE_DIR or ".testo_cache".
flagDir = flag.String(
"cache.dir",
cmp.Or(os.Getenv("TESTO_CACHE_DIR"), ".testo_cache"),
"directory where the testo cache is stored",
)
// flagDisable disables caching; defaults to the parsed value of $TESTO_CACHE_DISABLE.
flagDisable = flag.Bool(
"cache.disable",
parseBool(os.Getenv("TESTO_CACHE_DISABLE")),
"disable caching in testo",
)
)
// ErrDisabled indicates that caching is disabled.
// It is returned by the cache functions when [Disabled] reports true.
var ErrDisabled = errors.New("cache is disabled")
// Permissions used for created cache entries.
const (
// permFile is the mode of cache files.
permFile os.FileMode = 0o600
// permDir is the mode of the cache directory.
permDir os.FileMode = 0o750
)
// Disabled reports whether caching is disabled.
// It's up to the package user to handle disabled state,
// e.g. do not save objects in cache when this function returns true.
func Disabled() bool {
	disabled := *flagDisable

	return disabled
}
// kvMu guards access to cache entries within the current process.
var kvMu sync.RWMutex
// Keys returns all glob-matched keys by the given pattern.
// E.g. "myplugin-prefix-*"
//
// The internal ".gitignore" file placed in the cache directory
// is not a cache entry and is never reported as a key.
//
// If cache is disabled (see [Disabled]), this function returns [ErrDisabled].
func Keys(pattern string) (keys []string, err error) {
	dir, err := cacheDir()
	if err != nil {
		return nil, err
	}

	kvMu.RLock()
	defer kvMu.RUnlock()

	matches, err := fs.Glob(os.DirFS(dir), pattern)
	if err != nil {
		return nil, err
	}

	// Filter in place; the result stays nil when nothing matched.
	keys = matches[:0]

	for _, match := range matches {
		// Patterns such as "*" also match the service file, which is not a key.
		if match == ".gitignore" {
			continue
		}

		keys = append(keys, match)
	}

	return keys, nil
}
// Get returns the cached object stored under the given key.
// The key is sanitized to form a file name inside the cache directory.
//
// If cache is disabled (see [Disabled]), this function returns [ErrDisabled].
func Get(key string) ([]byte, error) {
	dir, err := cacheDir()
	if err != nil {
		return nil, err
	}

	filename := sanitizeFilename(key)

	kvMu.RLock()
	defer kvMu.RUnlock()

	return os.ReadFile(filepath.Join(dir, filename))
}
// Set saves value to cache under the given key.
// The key is sanitized to form a file name inside the cache directory.
//
// If cache is disabled (see [Disabled]), this function returns [ErrDisabled].
func Set(key string, value []byte) error {
	dir, err := cacheDir()
	if err != nil {
		return err
	}

	filename := sanitizeFilename(key)

	kvMu.Lock()
	defer kvMu.Unlock()

	return os.WriteFile(filepath.Join(dir, filename), value, permFile)
}
// Remove deletes the object stored in cache under the given key.
// The key is sanitized to form a file name inside the cache directory.
//
// If cache is disabled (see [Disabled]), this function returns [ErrDisabled].
func Remove(key string) error {
	dir, err := cacheDir()
	if err != nil {
		return err
	}

	filename := sanitizeFilename(key)

	kvMu.Lock()
	defer kvMu.Unlock()

	return os.Remove(filepath.Join(dir, filename))
}
// cacheDir returns the cache directory, creating it if needed.
// It also (re)writes a ".gitignore" file in it which ignores all cache contents.
//
// It returns [ErrDisabled] if caching is disabled.
func cacheDir() (string, error) {
	if Disabled() {
		return "", ErrDisabled
	}

	dir := *flagDir

	err := os.MkdirAll(dir, permDir)
	if err == nil {
		err = os.WriteFile(filepath.Join(dir, ".gitignore"), []byte("/*"), permFile)
	}

	if err != nil {
		return "", err
	}

	return dir, nil
}
// parseBool parses s as a boolean accepted by [strconv.ParseBool].
// Any value that cannot be parsed (including an empty string) yields false.
func parseBool(s string) bool {
	if value, err := strconv.ParseBool(s); err == nil {
		return value
	}

	return false
}
// sanitizeFilename turns name into a string usable as a file name.
//
// Control characters (including NUL) and the characters \ / < > : " | ? * .
// are each replaced with '-'. Other characters are kept as is.
func sanitizeFilename(name string) string {
	const (
		invalid     = `\/<>:\"|?*.`
		replacement = '-'
	)

	sanitize := func(r rune) rune {
		if r == 0 || unicode.IsControl(r) || strings.ContainsRune(invalid, r) {
			return replacement
		}

		return r
	}

	return strings.Map(sanitize, name)
}
package testoplugin
import (
"context"
"time"
)
// Overrides defines all builtin methods of T a plugin can override.
//
// Overrides work using middleware pattern - multiple overrides are stacked on top of each other,
// each receiving a "next" function.
//
// There exists a certain hierarchy what method calls what underneath.
// For example, overriding Log will affect Error, Skip, Fatal and their printf equivalents.
//
// A zero-value (nil) field means the method is not overridden.
type Overrides struct {
Log Override[FuncLog]
Parallel Override[FuncParallel]
TempDir Override[FuncTempDir]
Deadline Override[FuncDeadline]
Context Override[FuncContext]
Cleanup Override[FuncCleanup]
// Setenv calls Cleanup to restore an environment variable.
// On error, it calls Fatal.
Setenv Override[FuncSetenv]
// Chdir calls Cleanup to restore a current directory.
// On error, it calls Fatal.
Chdir Override[FuncChdir]
// Error calls Log followed by Fail.
Error Override[FuncError]
// Skip calls Log followed by SkipNow.
Skip Override[FuncSkip]
SkipNow Override[FuncSkipNow]
Skipped Override[FuncSkipped]
Fail Override[FuncFail]
FailNow Override[FuncFailNow]
Failed Override[FuncFailed]
// Fatal calls Log followed by FailNow.
Fatal Override[FuncFatal]
}
// Signatures of the T methods which can be overridden through [Overrides].
type (
// FuncParallel describes [testing.T.Parallel] method.
FuncParallel func()
// FuncSetenv describes [testing.T.Setenv] method.
FuncSetenv func(key, value string)
// FuncTempDir describes [testing.T.TempDir] method.
FuncTempDir func() string
// FuncLog describes [testing.T.Log] method.
FuncLog func(args ...any)
// FuncDeadline describes [testing.T.Deadline] method.
FuncDeadline func() (deadline time.Time, ok bool)
// FuncContext describes [testing.T.Context] method.
FuncContext func() context.Context
// FuncChdir describes [testing.T.Chdir] method.
FuncChdir func(dir string)
// FuncError describes [testing.T.Error] method.
FuncError func(args ...any)
// FuncSkip describes [testing.T.Skip] method.
FuncSkip func(args ...any)
// FuncSkipNow describes [testing.T.SkipNow] method.
FuncSkipNow func()
// FuncSkipped describes [testing.T.Skipped] method.
FuncSkipped func() bool
// FuncFail describes [testing.T.Fail] method.
FuncFail func()
// FuncFailNow describes [testing.T.FailNow] method.
FuncFailNow func()
// FuncFailed describes [testing.T.Failed] method.
FuncFailed func() bool
// FuncFatal describes [testing.T.Fatal] method.
FuncFatal func(args ...any)
// FuncCleanup describes [testing.T.Cleanup] method.
FuncCleanup func(f func())
)
// Override wraps a function of type F and returns its replacement.
//
// Nil value is valid and represents absence of override.
//
// Use [Override.Call] to safely call an override.
type Override[F any] func(f F) F

// Call applies the override to f and returns the resulting function.
// A nil override leaves f untouched.
func (o Override[F]) Call(f F) F {
	if o != nil {
		return o(f)
	}

	return f
}
// Package testoplugin provides plugin primitives for using plugins in testo.
//
// # Implementing a plugin
//
// Plugins can implement [Plugin] interface to be registered as such.
//
// Method "Plugin" will be called for each plugin before running a suite.
// Parent is nil for top-level tests. For sub-tests it refers to the plugin instance of the parent test.
//
// It is encouraged to ensure a plugin implements [Plugin] interface with the following line:
//
// var _ testoplugin.Plugin = (*PluginFoo)(nil)
package testoplugin
import "math"
// Priority defines execution order (priority).
// It defines when a plugin part should be invoked when other parts are available.
//
// "Plugin part" means plan, hook, override, etc.
// Right now, only [Hook]s are supported.
type Priority int
// Boundary priorities: the minimum and the maximum values of int.
const (
// TryFirst indicates that this plugin part should be run as early as possible.
TryFirst Priority = math.MinInt
// TryLast indicates that this plugin part should be run as late as possible.
TryLast Priority = math.MaxInt
)
// Option is used to configure plugin upon creation.
//
// All user-supplied options are passed to the Plugin method for each plugin.
// It is a plugin responsibility to check if the given option corresponds to it.
// One way to check it is with type assertion:
//
//	var opt Option
//	o, ok := opt.Value.(MyPluginOption)
type Option struct {
	// Value of this option.
	Value any

	// Propagate states whether this option
	// should be passed automatically to all subtests.
	Propagate bool
}

// Propagated returns a shallow clone of this option
// with "Propagate" field set to true.
// The receiver is left unchanged.
func (o Option) Propagated() Option {
	return Option{
		Value:     o.Value,
		Propagate: true,
	}
}
// Plugin is an interface that plugins implement to provide
// [Plan], [Hooks] and [Overrides] to the tests.
type Plugin interface {
// Plugin returns the specification of this plugin.
// It receives the parent plugin and the user-supplied options.
Plugin(parent Plugin, options ...Option) Spec
}
// Spec is a plugin specification:
// everything a plugin provides to the tests.
type Spec struct {
Plan Plan
Hooks Hooks
Overrides Overrides
}
// Package testoreflect provides reflection primitives for T.
//
// You can obtain [Reflection] for any test by accessing its T instance:
//
// func (Suite) TestFoo(t T) {
// r := testo.Reflect(t)
// // r stores Reflection struct.
// }
//
// Same logic applies for plugins.
// If a plugin embeds `*testo.T` it can call the same [testo.Reflect] function:
//
// type Plugin struct{ *testo.T }
//
// func (p *Plugin) Plugin(parent testoplugin.Plugin, options ...testoplugin.Option) testoplugin.Spec {
// return testoplugin.Spec{
// Hooks: testoplugin.Hooks{
// BeforeEach: testoplugin.Hook{
// Func: func() { testo.Reflect(p) }
// }
// }
// }
// }
package testoreflect
import (
"fmt"
"testing"
"time"
)
// Reflection is the meta information about T and the test it belongs to.
type Reflection struct {
// Suite info which this test belongs to.
Suite SuiteInfo
// Test holds information about current test.
Test TestInfo
// Panic is panic information.
// It is nil if the test did not panic.
Panic *PanicInfo
// FailureKind is test failure kind.
FailureKind TestFailureKind
// FailureSource is the test failure source.
FailureSource TestFailureSource
// HasFatalSubtest is true if any (nested) subtest of this test
// has [TestFailureKind] equal to [TestFailureKindFatal].
HasFatalSubtest bool
// TestingT provides access to real [testing.T]
// without overrides or hooks applied.
TestingT TestingT
}
// TestingT is an interface for [testing.T]:
// [testing.TB] extended with Deadline, Parallel and Run methods.
type TestingT interface {
testing.TB
Deadline() (deadline time.Time, ok bool)
Parallel()
Run(name string, f func(t *testing.T)) bool
}
// TestFailureSource states where test failure appeared.
type TestFailureSource int

const (
	// TestFailureSourceNone states that test did not fail.
	TestFailureSourceNone TestFailureSource = iota

	// TestFailureSourceSelf states that test failed itself.
	TestFailureSourceSelf

	// TestFailureSourceChild states that test failed because of child failure.
	TestFailureSourceChild
)

// String implements [fmt.Stringer].
//
// Known values are rendered as "none", "self" and "child";
// any other value is rendered as "TestFailureSource(N)".
func (s TestFailureSource) String() string {
	names := [...]string{
		TestFailureSourceNone:  "none",
		TestFailureSourceSelf:  "self",
		TestFailureSourceChild: "child",
	}

	if s >= 0 && int(s) < len(names) {
		return names[s]
	}

	return fmt.Sprintf("TestFailureSource(%d)", s)
}
// TestFailureKind defines test failure kinds.
type TestFailureKind int

const (
	// TestFailureKindNone states that test did not fail.
	TestFailureKindNone TestFailureKind = iota

	// TestFailureKindSoft states that test failed but it was not fatal.
	// For example, t.Fail() was called.
	TestFailureKindSoft

	// TestFailureKindFatal states that test failed with fatal error.
	// For example, t.FailNow() or t.Fatal() were called.
	TestFailureKindFatal
)

// String implements [fmt.Stringer].
//
// Known values are rendered as "none", "soft" and "fatal";
// any other value is rendered as "TestFailureKind(N)".
func (k TestFailureKind) String() string {
	names := [...]string{
		TestFailureKindNone:  "none",
		TestFailureKindSoft:  "soft",
		TestFailureKindFatal: "fatal",
	}

	if k >= 0 && int(k) < len(names) {
		return names[k]
	}

	return fmt.Sprintf("TestFailureKind(%d)", k)
}
// PanicInfo holds information about a recovered panic.
type PanicInfo struct {
// Value returned by recover().
Value any
// Trace is a stack trace for this panic.
Trace string
}
// TestInfo is an enum which is
// either [ParametrizedTestInfo] or [RegularTestInfo].
//
// switch info := info.(type) {
// case testoreflect.ParametrizedTestInfo:
// case testoreflect.RegularTestInfo:
// }
type TestInfo interface {
// GetName returns full test name as would be returned by t.Name().
GetName() string
// GetLevel returns test level (depth).
//
// func (Suite) BeforeAll(t T) {} // level 0
//
// func (Suite) BeforeEach(t T) {} // level 1
//
// func (Suite) Test(t T) { // level 1
// testo.Run(t, "...", func(t T) { // level 2
// testo.Run(t, "...", func(t T) { // level 3
// })
// })
// }
GetLevel() int
// isTestInfo is unexported, so the interface
// cannot be implemented outside of this package.
isTestInfo()
}
// ParametrizedTestInfo is the information about parametrized test.
type ParametrizedTestInfo struct {
	// Name is a full test name as would be returned by t.Name().
	Name string

	// BaseName of the test.
	BaseName string

	// Index is the 0-based case number of this test.
	Index int

	// CasesCount is the total cases count for this test.
	CasesCount int

	// Params passed for the current test case.
	Params map[string]any

	// FuncPC is the program counter (PC) of this test function.
	//
	// NOTE: it may be empty (zero) in some cases, e.g. hooks and sub-tests.
	// Use with caution.
	FuncPC uintptr
}

// GetName returns value of [ParametrizedTestInfo.Name] field.
func (i ParametrizedTestInfo) GetName() string {
	return i.Name
}

// GetLevel always returns level 1 since parametrized tests can't be nested.
func (ParametrizedTestInfo) GetLevel() int {
	const level = 1

	return level
}

// isTestInfo marks the type as a TestInfo variant.
func (ParametrizedTestInfo) isTestInfo() {}
// SuiteInfo is the information about a suite and the test it was run from.
type SuiteInfo struct {
// Name of this suite.
Name string
// Caller is the full test name from
// which this suite was run.
Caller string
// TestingT is the [testing.T] of the caller.
//
// func TestFoo(t *testing.T) {
// testo.RunSuite(t, new(Suite))
// // ^ that t
// }
TestingT TestingT
// Value holds suite value.
Value any
// Hooks information for this suite.
Hooks SuiteHooksInfo
}
// SuiteHooksInfo is suite hooks information.
type SuiteHooksInfo struct {
// MissedBeforeAll is true if suite did not call BeforeAll hook,
// meaning it was not defined.
MissedBeforeAll bool
// MissedBeforeEach is true if suite did not call BeforeEach hook.
// It is either not defined or failed beforehand.
MissedBeforeEach bool
// MissedAfterEach is true if suite did not call AfterEach hook.
// It is either not defined or failed beforehand.
MissedAfterEach bool
// MissedAfterAll is true if suite did not call AfterAll hook,
// meaning it was not defined.
MissedAfterAll bool
}
// RegularTestInfo describes a regular (non-parametrized) test.
type RegularTestInfo struct {
	// Name is the full test name, exactly as t.Name() would report it.
	Name string

	// RawBaseName is the raw, "unformatted" base name of this test.
	//
	// For example, given
	//
	//	Run(t, "my test name!?", func(...) { ... })
	//
	// t.Name() would be equal to "SomeSuite/my_test_name",
	// while this field would be equal to "my test name!?" (exactly as passed).
	RawBaseName string

	// Level indicates how deep this t is, i.e. the number
	// of parents it has (zero if it has none).
	//
	// Level zero is for Before/After-All hooks.
	// Subsequent levels are for test methods or subtests run from Before/After-All hooks.
	//
	// To tell a subtest run from a Before/After-All hook apart from a test method,
	// see the IsSubtest field.
	Level int

	// IsSubtest reports whether this test is a subtest.
	//
	// Example of a subtest:
	//
	//	testo.Run(t, "...", func(T) {})
	//
	// Example of a test method (not a subtest):
	//
	//	func (Suite) TestFoo(T)
	IsSubtest bool

	// FuncPC is the program counter (PC) of this test function.
	//
	// NOTE: it may be empty (zero) in some cases, e.g. hooks and sub-tests.
	// Use with caution.
	FuncPC uintptr
}

// GetName returns the value of the [RegularTestInfo.Name] field.
func (r RegularTestInfo) GetName() string {
	return r.Name
}

// GetLevel returns the value of the [RegularTestInfo.Level] field.
func (r RegularTestInfo) GetLevel() int {
	return r.Level
}

// isTestInfo is a no-op marker method.
func (r RegularTestInfo) isTestInfo() {
}