From f62506c40f7fa19d35df4e72adf818bc00dbb1ff Mon Sep 17 00:00:00 2001 From: Matthias <5011972+fasmat@users.noreply.github.com> Date: Thu, 30 Apr 2026 19:44:57 +0200 Subject: [PATCH] Allow pointer arguments to be passed to mg.F --- mage/main.go | 2 +- mg/deps.go | 12 +++++----- mg/deps_internal_test.go | 2 +- mg/errors.go | 4 ++-- mg/fn.go | 28 +++++++++++++++-------- mg/fn_test.go | 49 ++++++++++++++++++++++++++++++++++++---- 6 files changed, 72 insertions(+), 25 deletions(-) diff --git a/mage/main.go b/mage/main.go index b51ea95c..73962479 100644 --- a/mage/main.go +++ b/mage/main.go @@ -53,7 +53,7 @@ func lowerFirstWord(s string) string { return strings.ToLower(s) } -var mainfileTemplate = template.Must(template.New("").Funcs(map[string]interface{}{ +var mainfileTemplate = template.Must(template.New("").Funcs(map[string]any{ "lower": strings.ToLower, "lowerFirst": func(s string) string { parts := strings.Split(s, ":") diff --git a/mg/deps.go b/mg/deps.go index bed1a972..93d98266 100644 --- a/mg/deps.go +++ b/mg/deps.go @@ -52,7 +52,7 @@ var onces = &onceMap{ // SerialDeps is like Deps except it runs each dependency serially, instead of // in parallel. This can be useful for resource intensive dependencies that // shouldn't be run at the same time. -func SerialDeps(fns ...interface{}) { +func SerialDeps(fns ...any) { funcs := checkFns(fns) ctx := context.Background() for i := range fns { @@ -63,7 +63,7 @@ func SerialDeps(fns ...interface{}) { // SerialCtxDeps is like CtxDeps except it runs each dependency serially, // instead of in parallel. This can be useful for resource intensive // dependencies that shouldn't be run at the same time. -func SerialCtxDeps(ctx context.Context, fns ...interface{}) { +func SerialCtxDeps(ctx context.Context, fns ...any) { funcs := checkFns(fns) for i := range fns { runDeps(ctx, funcs[i:i+1]) @@ -86,7 +86,7 @@ func SerialCtxDeps(ctx context.Context, fns ...interface{}) { // their own dependencies using Deps. Each dependency is run in their own // goroutines. Each function is given the context provided if the function // prototype allows for it. -func CtxDeps(ctx context.Context, fns ...interface{}) { +func CtxDeps(ctx context.Context, fns ...any) { funcs := checkFns(fns) runDeps(ctx, funcs) } @@ -129,7 +129,7 @@ func runDeps(ctx context.Context, fns []Fn) { } } -func checkFns(fns []interface{}) []Fn { +func checkFns(fns []any) []Fn { funcs := make([]Fn, len(fns)) for i, f := range fns { if fn, ok := f.(Fn); ok { @@ -162,7 +162,7 @@ func checkFns(fns []interface{}) []Fn { // This is a way to build up a tree of dependencies with each dependency // defining its own dependencies. Functions must have the same signature as a // Mage target, i.e. optional context argument, optional error return. -func Deps(fns ...interface{}) { +func Deps(fns ...any) { CtxDeps(context.Background(), fns...) } @@ -182,7 +182,7 @@ func changeExit(old, nw int) int { } // funcName returns the unique name for the function. -func funcName(i interface{}) string { +func funcName(i any) string { return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name() } diff --git a/mg/deps_internal_test.go b/mg/deps_internal_test.go index b5f1d87a..ce49bcae 100644 --- a/mg/deps_internal_test.go +++ b/mg/deps_internal_test.go @@ -48,7 +48,7 @@ func TestDepWasNotInvoked(t *testing.T) { t.Fatalf(`expected to get "%s" but got "%s"`, wantErr, gotErr) } }() - func(fns ...interface{}) { + func(fns ...any) { checkFns(fns) }(fn1()) } diff --git a/mg/errors.go b/mg/errors.go index fc363291..fdf33050 100644 --- a/mg/errors.go +++ b/mg/errors.go @@ -21,7 +21,7 @@ type exitStatus interface { // Fatal returns an error that will cause mage to print out the // given args and exit with the given exit code. -func Fatal(code int, args ...interface{}) error { +func Fatal(code int, args ...any) error { return fatalError{ code: code, error: errors.New(fmt.Sprint(args...)), @@ -30,7 +30,7 @@ func Fatal(code int, args ...interface{}) error { // Fatalf returns an error that will cause mage to print out the // given message and exit with the given exit code. -func Fatalf(code int, format string, args ...interface{}) error { +func Fatalf(code int, format string, args ...any) error { return fatalError{ code: code, error: fmt.Errorf(format, args...), diff --git a/mg/fn.go b/mg/fn.go index 832728b6..1435c487 100644 --- a/mg/fn.go +++ b/mg/fn.go @@ -31,7 +31,7 @@ type Fn interface { // are declared by the function. Note that you do not need to and should not pass a context.Context // to F, even if the target takes a context. Compatible args are int, bool, string, and // time.Duration. -func F(target interface{}, args ...interface{}) Fn { +func F(target any, args ...any) Fn { hasContext, isNamespace, err := checkF(target, args) if err != nil { panic(err) @@ -103,7 +103,7 @@ func (f fn) Run(ctx context.Context) error { return f.f(ctx) } -func checkF(target interface{}, args []interface{}) (hasContext, isNamespace bool, _ error) { +func checkF(target any, args []any) (hasContext, isNamespace bool, _ error) { t := reflect.TypeOf(target) if t == nil || t.Kind() != reflect.Func { return false, false, fmt.Errorf("non-function passed to mg.F: %T. The mg.F function accepts function names, such as mg.F(TargetA, \"arg1\", \"arg2\")", target) @@ -181,16 +181,24 @@ var ( errType = reflect.TypeOf(func() error { return nil }).Out(0) emptyType = reflect.TypeOf(struct{}{}) - intType = reflect.TypeOf(int(0)) - stringType = reflect.TypeOf(string("")) - boolType = reflect.TypeOf(bool(false)) - durType = reflect.TypeOf(time.Second) + intType = reflect.TypeOf(int(0)) + intPtrType = reflect.TypeOf((*int)(nil)) + stringType = reflect.TypeOf(string("")) + stringPtrType = reflect.TypeOf((*string)(nil)) + boolType = reflect.TypeOf(bool(false)) + boolPtrType = reflect.TypeOf((*bool)(nil)) + durType = reflect.TypeOf(time.Second) + durPtrType = reflect.TypeOf((*time.Duration)(nil)) // don't put ctx in here, this is for non-context types. argTypes = map[reflect.Type]bool{ - intType: true, - boolType: true, - stringType: true, - durType: true, + intType: true, + intPtrType: true, + boolType: true, + boolPtrType: true, + stringType: true, + stringPtrType: true, + durType: true, + durPtrType: true, } ) diff --git a/mg/fn_test.go b/mg/fn_test.go index 95f2b58c..c65fefcf 100644 --- a/mg/fn_test.go +++ b/mg/fn_test.go @@ -95,7 +95,7 @@ func TestFuncCheck(t *testing.T) { t.Error("func is on a namespace") } - hasContext, isNamespace, err = checkF(Foo.CtxErrorArgs, []interface{}{1, "s", true, time.Second}) + hasContext, isNamespace, err = checkF(Foo.CtxErrorArgs, []any{1, "s", true, time.Second}) if err != nil { t.Error(err) } @@ -106,7 +106,7 @@ func TestFuncCheck(t *testing.T) { t.Error("func is on a namespace") } - hasContext, isNamespace, err = checkF(func(int, bool, string, time.Duration) {}, []interface{}{1, true, "s", time.Second}) + hasContext, isNamespace, err = checkF(func(int, bool, string, time.Duration) {}, []any{1, true, "s", time.Second}) if err != nil { t.Error(err) } @@ -128,7 +128,7 @@ func TestFuncCheck(t *testing.T) { t.Error("expected a nil function argument to be handled gracefully") } }() - _, _, err = checkF(nil, []interface{}{1, 2}) + _, _, err = checkF(nil, []any{1, 2}) if err == nil { t.Error("expected a nil function argument to be invalid") } @@ -136,7 +136,7 @@ func TestFuncCheck(t *testing.T) { func TestF(t *testing.T) { var ( - ctxOut interface{} + ctxOut any iOut int sOut string bOut bool @@ -221,6 +221,20 @@ func TestFNilError(t *testing.T) { } } +func TestFPointerArg(t *testing.T) { + value := new(int) + *value = 1776 + fn := F(func(i *int) { + if *i != 1776 { + t.Errorf("Wrong arg, got %d, want 1776", *i) + } + }, value) + err := fn.Run(context.Background()) + if err != nil { + t.Fatal(err) + } +} + func TestFVariadic(t *testing.T) { fn := F(func(args ...string) { if !reflect.DeepEqual(args, []string{"a", "b"}) { @@ -299,6 +313,31 @@ func TestFWrongArgType(t *testing.T) { F(func(int) {}, "not an int") } +func TestFValueToPointerArg(t *testing.T) { + defer func() { + r := recover() + if r == nil { + t.Fatal("expected panic for value arg when pointer expected") + } + }() + i := 1776 + F(func(*int) {}, i) +} + +func TestFPointerToValueArg(t *testing.T) { + defer func() { + r := recover() + if r == nil { + t.Fatal("expected panic for pointer arg when value expected") + } + }() + i := new(int) + *i = 1776 + F(func(int) {}, i) +} + +type Unsupported struct{} + func TestFUnsupportedArgType(t *testing.T) { defer func() { r := recover() @@ -306,7 +345,7 @@ func TestFUnsupportedArgType(t *testing.T) { t.Fatal("expected panic for unsupported arg type") } }() - F(func(*int) {}, (*int)(nil)) + F(func(*Unsupported) {}, (*Unsupported)(nil)) } func TestFTooManyReturns(t *testing.T) {