diff --git a/func.go b/func.go index 1a4552de..cb07d40c 100644 --- a/func.go +++ b/func.go @@ -60,7 +60,7 @@ func RegisterLibFunc(fptr any, handle uintptr, name string) { // int64 <=> int64_t // float32 <=> float // float64 <=> double -// struct <=> struct (WIP - darwin only) +// struct <=> struct (darwin amd64/arm64, linux amd64/arm64) // func <=> C function // unsafe.Pointer, *T <=> void* // []T => void* @@ -169,9 +169,7 @@ func RegisterFunc(fptr any, cfn uintptr) { stack++ } case reflect.Struct: - if runtime.GOOS != "darwin" || (runtime.GOARCH != "amd64" && runtime.GOARCH != "arm64") { - panic("purego: struct arguments are only supported on darwin amd64 & arm64") - } + ensureStructSupportedForRegisterFunc() if arg.Size() == 0 { continue } @@ -190,9 +188,7 @@ func RegisterFunc(fptr any, cfn uintptr) { } } if ty.NumOut() == 1 && ty.Out(0).Kind() == reflect.Struct { - if runtime.GOOS != "darwin" { - panic("purego: struct return values only supported on darwin arm64 & amd64") - } + ensureStructSupportedForRegisterFunc() outType := ty.Out(0) checkStructFieldsSupported(outType) if runtime.GOARCH == "amd64" && outType.Size() > maxRegAllocStructSize { @@ -465,13 +461,23 @@ func checkStructFieldsSupported(ty reflect.Type) { switch f.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, - reflect.Uintptr, reflect.Ptr, reflect.UnsafePointer, reflect.Float64, reflect.Float32: + reflect.Uintptr, reflect.Ptr, reflect.UnsafePointer, reflect.Float64, reflect.Float32, + reflect.Bool: default: panic(fmt.Sprintf("purego: struct field type %s is not supported", f)) } } } +func ensureStructSupportedForRegisterFunc() { + if runtime.GOARCH != "amd64" && runtime.GOARCH != "arm64" { + panic("purego: struct arguments are only supported on amd64 and arm64") + } + if runtime.GOOS != "darwin" && runtime.GOOS != "linux" { + panic("purego: struct arguments are only supported on darwin and linux") + } +} + func roundUpTo8(val uintptr) uintptr { return (val + 7) &^ 7 } diff --git a/struct_amd64.go b/struct_amd64.go index c4c2ad8f..8bbd6faa 100644 --- a/struct_amd64.go +++ b/struct_amd64.go @@ -169,9 +169,10 @@ func tryPlaceRegister(v reflect.Value, addFloat func(uintptr), addInt func(uintp } shift += 8 class |= _INTEGER - case reflect.Pointer: - ok = false - return + case reflect.Pointer, reflect.UnsafePointer: + val = uint64(f.Pointer()) + shift = 64 + class = _INTEGER case reflect.Int8: val |= uint64(f.Int()&0xFF) << shift shift += 8 @@ -241,7 +242,7 @@ func placeStack(v reflect.Value, addStack func(uintptr)) { for i := 0; i < v.Type().NumField(); i++ { f := v.Field(i) switch f.Kind() { - case reflect.Pointer: + case reflect.Pointer, reflect.UnsafePointer: addStack(f.Pointer()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: addStack(uintptr(f.Int())) diff --git a/struct_arm64.go b/struct_arm64.go index 8605e77b..c8fe16b6 100644 --- a/struct_arm64.go +++ b/struct_arm64.go @@ -177,7 +177,7 @@ func placeRegisters(v reflect.Value, addFloat func(uintptr), addInt func(uintptr shift = 0 flushed = true class = _NO_CLASS - case reflect.Ptr: + case reflect.Ptr, reflect.UnsafePointer: addInt(f.Pointer()) shift = 0 flushed = true diff --git a/struct_test.go b/struct_test.go index 4e3483b3..03cd83bf 100644 --- a/struct_test.go +++ b/struct_test.go @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: 2024 The Ebitengine Authors -//go:build darwin && (arm64 || amd64) +//go:build (darwin || linux) && (amd64 || arm64) package purego_test @@ -373,8 +373,8 @@ func TestRegisterFunc_structArgs(t *testing.T) { } var Array4CharsFn func(chars Array4Chars) int32 purego.RegisterLibFunc(&Array4CharsFn, lib, "Array4Chars") - const expectedSum = 1 + 2 + 4 + 8 - if ret := Array4CharsFn(Array4Chars{a: [...]int8{1, 2, 4, 8}}); ret != expectedSum { + const expectedSum = 123 + if ret := Array4CharsFn(Array4Chars{a: [...]int8{10, 20, 30, 63}}); ret != expectedSum { t.Fatalf("Array4CharsFn returned %d wanted %d", ret, expectedSum) } } @@ -499,6 +499,37 @@ func TestRegisterFunc_structArgs(t *testing.T) { t.Fatalf("FourInt32s returned %d wanted %d", result, want) } } + { + type PointerWrapper struct { + ctx unsafe.Pointer + } + var ExtractPointer func(wrapper PointerWrapper) uintptr + purego.RegisterLibFunc(&ExtractPointer, lib, "ExtractPointer") + + var v int + ptr := unsafe.Pointer(&v) + expected := uintptr(ptr) + result := ExtractPointer(PointerWrapper{ctx: ptr}) + if result != expected { + t.Fatalf("ExtractPointer returned %#x wanted %#x", result, expected) + } + } + { + type TwoPointers struct { + ptr1, ptr2 unsafe.Pointer + } + var AddPointers func(wrapper TwoPointers) uintptr + purego.RegisterLibFunc(&AddPointers, lib, "AddPointers") + + var v1, v2 int + ptr1 := unsafe.Pointer(&v1) + ptr2 := unsafe.Pointer(&v2) + expected := uintptr(ptr1) + uintptr(ptr2) + result := AddPointers(TwoPointers{ptr1, ptr2}) + if result != expected { + t.Fatalf("AddPointers returned %#x wanted %#x", result, expected) + } + } } func TestRegisterFunc_structReturns(t *testing.T) { @@ -757,6 +788,49 @@ func TestRegisterFunc_structReturns(t *testing.T) { t.Fatalf("ReturnMixed4 returned %+v wanted %+v", ret, expected) } } + { + type Mixed5 struct { + a *int64 + b int32 + c float32 + d int32 + } + var ReturnMixed5 func(a *int64, b int32, c float32, d int32) Mixed5 + purego.RegisterLibFunc(&ReturnMixed5, lib, "ReturnMixed5") + ptr := new(int64) + expected := Mixed5{ptr, 1, 7.2, 9} + if ret := ReturnMixed5(ptr, 1, 7.2, 9); ret != expected { + t.Fatalf("ReturnMixed5 returned %+v wanted %+v", ret, expected) + } + runtime.KeepAlive(ptr) + } + { + type SmallBool struct { + a bool + b int32 + c int64 + } + var ReturnSmallBool func(a bool, b int32, c int64) SmallBool + purego.RegisterLibFunc(&ReturnSmallBool, lib, "ReturnSmallBool") + expected := SmallBool{true, 42, 123456789} + if ret := ReturnSmallBool(true, 42, 123456789); ret != expected { + t.Fatalf("ReturnSmallBool returned %+v wanted %+v", ret, expected) + } + } + { + type LargeBool struct { + a bool + b int32 + c int64 + d int64 + } + var ReturnLargeBool func(a bool, b int32, c int64, d int64) LargeBool + purego.RegisterLibFunc(&ReturnLargeBool, lib, "ReturnLargeBool") + expected := LargeBool{false, -99, 987654321, 111222333444} + if ret := ReturnLargeBool(false, -99, 987654321, 111222333444); ret != expected { + t.Fatalf("ReturnLargeBool returned %+v wanted %+v", ret, expected) + } + } { type Ptr1 struct { a *int64 diff --git a/testdata/structtest/struct_test.c b/testdata/structtest/struct_test.c index 4cbf8060..27d45267 100644 --- a/testdata/structtest/struct_test.c +++ b/testdata/structtest/struct_test.c @@ -280,7 +280,7 @@ unsigned int Array2UnsignedShorts(struct Array2UnsignedShort a) { } struct Array4Chars { - char a[4]; + signed char a[4]; }; int Array4Chars(struct Array4Chars a) { @@ -374,3 +374,20 @@ struct FourInt32s { int32_t FourInt32s(struct FourInt32s s) { return s.f0 + s.f1 + s.f2 + s.f3; } + +struct PointerWrapper { + void* ctx; +}; + +uintptr_t ExtractPointer(struct PointerWrapper wrapper) { + return (uintptr_t)wrapper.ctx; +} + +struct TwoPointers { + void* ptr1; + void* ptr2; +}; + +uintptr_t AddPointers(struct TwoPointers wrapper) { + return (uintptr_t)wrapper.ptr1 + (uintptr_t)wrapper.ptr2; +} diff --git a/testdata/structtest/structreturn_test.c b/testdata/structtest/structreturn_test.c index c55a0f97..b539ed37 100644 --- a/testdata/structtest/structreturn_test.c +++ b/testdata/structtest/structreturn_test.c @@ -227,6 +227,41 @@ struct Mixed4 ReturnMixed4(double a, uint32_t b, float c) { return s; } +struct Mixed5{ + int64_t *a; + int32_t b; + float c; + int32_t d; +}; + +struct Mixed5 ReturnMixed5(int64_t *a, int32_t b, float c, int32_t d) { + struct Mixed5 s = {a, b, c, d}; + return s; +} + +struct SmallBool{ + _Bool a; + int32_t b; + int64_t c; +}; + +struct SmallBool ReturnSmallBool(_Bool a, int32_t b, int64_t c) { + struct SmallBool s = {a, b, c}; + return s; +} + +struct LargeBool{ + _Bool a; + int32_t b; + int64_t c; + int64_t d; +}; + +struct LargeBool ReturnLargeBool(_Bool a, int32_t b, int64_t c, int64_t d) { + struct LargeBool s = {a, b, c, d}; + return s; +} + struct Ptr1{ int64_t *a; void *b;