Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions func.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func RegisterLibFunc(fptr any, handle uintptr, name string) {
// int64 <=> int64_t
// float32 <=> float
// float64 <=> double
// struct <=> struct (WIP - darwin only)
// struct <=> struct (darwin amd64/arm64, linux amd64/arm64)
// func <=> C function
// unsafe.Pointer, *T <=> void*
// []T => void*
Expand Down Expand Up @@ -169,9 +169,7 @@ func RegisterFunc(fptr any, cfn uintptr) {
stack++
}
case reflect.Struct:
if runtime.GOOS != "darwin" || (runtime.GOARCH != "amd64" && runtime.GOARCH != "arm64") {
panic("purego: struct arguments are only supported on darwin amd64 & arm64")
}
ensureStructSupportedForRegisterFunc()
if arg.Size() == 0 {
continue
}
Expand All @@ -190,9 +188,7 @@ func RegisterFunc(fptr any, cfn uintptr) {
}
}
if ty.NumOut() == 1 && ty.Out(0).Kind() == reflect.Struct {
if runtime.GOOS != "darwin" {
panic("purego: struct return values only supported on darwin arm64 & amd64")
}
ensureStructSupportedForRegisterFunc()
outType := ty.Out(0)
checkStructFieldsSupported(outType)
if runtime.GOARCH == "amd64" && outType.Size() > maxRegAllocStructSize {
Expand Down Expand Up @@ -465,13 +461,23 @@ func checkStructFieldsSupported(ty reflect.Type) {
switch f.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Uintptr, reflect.Ptr, reflect.UnsafePointer, reflect.Float64, reflect.Float32:
reflect.Uintptr, reflect.Ptr, reflect.UnsafePointer, reflect.Float64, reflect.Float32,
reflect.Bool:
default:
panic(fmt.Sprintf("purego: struct field type %s is not supported", f))
}
}
}

func ensureStructSupportedForRegisterFunc() {
if runtime.GOARCH != "amd64" && runtime.GOARCH != "arm64" {
panic("purego: struct arguments are only supported on amd64 and arm64")
}
if runtime.GOOS != "darwin" && runtime.GOOS != "linux" {
panic("purego: struct arguments are only supported on darwin and linux")
}
}

func roundUpTo8(val uintptr) uintptr {
return (val + 7) &^ 7
}
Expand Down
9 changes: 5 additions & 4 deletions struct_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,10 @@ func tryPlaceRegister(v reflect.Value, addFloat func(uintptr), addInt func(uintp
}
shift += 8
class |= _INTEGER
case reflect.Pointer:
ok = false
return
case reflect.Pointer, reflect.UnsafePointer:
val = uint64(f.Pointer())
shift = 64
class = _INTEGER
case reflect.Int8:
val |= uint64(f.Int()&0xFF) << shift
shift += 8
Expand Down Expand Up @@ -241,7 +242,7 @@ func placeStack(v reflect.Value, addStack func(uintptr)) {
for i := 0; i < v.Type().NumField(); i++ {
f := v.Field(i)
switch f.Kind() {
case reflect.Pointer:
case reflect.Pointer, reflect.UnsafePointer:
addStack(f.Pointer())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
addStack(uintptr(f.Int()))
Expand Down
2 changes: 1 addition & 1 deletion struct_arm64.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func placeRegisters(v reflect.Value, addFloat func(uintptr), addInt func(uintptr
shift = 0
flushed = true
class = _NO_CLASS
case reflect.Ptr:
case reflect.Ptr, reflect.UnsafePointer:
addInt(f.Pointer())
shift = 0
flushed = true
Expand Down
80 changes: 77 additions & 3 deletions struct_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: 2024 The Ebitengine Authors

//go:build darwin && (arm64 || amd64)
//go:build (darwin || linux) && (amd64 || arm64)

package purego_test

Expand Down Expand Up @@ -373,8 +373,8 @@ func TestRegisterFunc_structArgs(t *testing.T) {
}
var Array4CharsFn func(chars Array4Chars) int32
purego.RegisterLibFunc(&Array4CharsFn, lib, "Array4Chars")
const expectedSum = 1 + 2 + 4 + 8
if ret := Array4CharsFn(Array4Chars{a: [...]int8{1, 2, 4, 8}}); ret != expectedSum {
const expectedSum = 123
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional:

const expectedSum = 10 + 20 + 30 + 63

if ret := Array4CharsFn(Array4Chars{a: [...]int8{10, 20, 30, 63}}); ret != expectedSum {
t.Fatalf("Array4CharsFn returned %d wanted %d", ret, expectedSum)
}
}
Expand Down Expand Up @@ -499,6 +499,37 @@ func TestRegisterFunc_structArgs(t *testing.T) {
t.Fatalf("FourInt32s returned %d wanted %d", result, want)
}
}
{
type PointerWrapper struct {
ctx unsafe.Pointer
}
var ExtractPointer func(wrapper PointerWrapper) uintptr
purego.RegisterLibFunc(&ExtractPointer, lib, "ExtractPointer")

var v int
ptr := unsafe.Pointer(&v)
expected := uintptr(ptr)
result := ExtractPointer(PointerWrapper{ctx: ptr})
if result != expected {
t.Fatalf("ExtractPointer returned %#x wanted %#x", result, expected)
}
}
{
type TwoPointers struct {
ptr1, ptr2 unsafe.Pointer
}
var AddPointers func(wrapper TwoPointers) uintptr
purego.RegisterLibFunc(&AddPointers, lib, "AddPointers")

var v1, v2 int
ptr1 := unsafe.Pointer(&v1)
ptr2 := unsafe.Pointer(&v2)
expected := uintptr(ptr1) + uintptr(ptr2)
result := AddPointers(TwoPointers{ptr1, ptr2})
if result != expected {
t.Fatalf("AddPointers returned %#x wanted %#x", result, expected)
}
}
}

func TestRegisterFunc_structReturns(t *testing.T) {
Expand Down Expand Up @@ -757,6 +788,49 @@ func TestRegisterFunc_structReturns(t *testing.T) {
t.Fatalf("ReturnMixed4 returned %+v wanted %+v", ret, expected)
}
}
{
type Mixed5 struct {
a *int64
b int32
c float32
d int32
}
var ReturnMixed5 func(a *int64, b int32, c float32, d int32) Mixed5
purego.RegisterLibFunc(&ReturnMixed5, lib, "ReturnMixed5")
ptr := new(int64)
expected := Mixed5{ptr, 1, 7.2, 9}
if ret := ReturnMixed5(ptr, 1, 7.2, 9); ret != expected {
t.Fatalf("ReturnMixed5 returned %+v wanted %+v", ret, expected)
}
runtime.KeepAlive(ptr)
}
{
type SmallBool struct {
a bool
b int32
c int64
}
var ReturnSmallBool func(a bool, b int32, c int64) SmallBool
purego.RegisterLibFunc(&ReturnSmallBool, lib, "ReturnSmallBool")
expected := SmallBool{true, 42, 123456789}
if ret := ReturnSmallBool(true, 42, 123456789); ret != expected {
t.Fatalf("ReturnSmallBool returned %+v wanted %+v", ret, expected)
}
}
{
type LargeBool struct {
a bool
b int32
c int64
d int64
}
var ReturnLargeBool func(a bool, b int32, c int64, d int64) LargeBool
purego.RegisterLibFunc(&ReturnLargeBool, lib, "ReturnLargeBool")
expected := LargeBool{false, -99, 987654321, 111222333444}
if ret := ReturnLargeBool(false, -99, 987654321, 111222333444); ret != expected {
t.Fatalf("ReturnLargeBool returned %+v wanted %+v", ret, expected)
}
}
{
type Ptr1 struct {
a *int64
Expand Down
19 changes: 18 additions & 1 deletion testdata/structtest/struct_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ unsigned int Array2UnsignedShorts(struct Array2UnsignedShort a) {
}

struct Array4Chars {
char a[4];
signed char a[4];
};

int Array4Chars(struct Array4Chars a) {
Expand Down Expand Up @@ -374,3 +374,20 @@ struct FourInt32s {
int32_t FourInt32s(struct FourInt32s s) {
return s.f0 + s.f1 + s.f2 + s.f3;
}

struct PointerWrapper {
void* ctx;
};

uintptr_t ExtractPointer(struct PointerWrapper wrapper) {
return (uintptr_t)wrapper.ctx;
}

struct TwoPointers {
void* ptr1;
void* ptr2;
};

uintptr_t AddPointers(struct TwoPointers wrapper) {
return (uintptr_t)wrapper.ptr1 + (uintptr_t)wrapper.ptr2;
}
35 changes: 35 additions & 0 deletions testdata/structtest/structreturn_test.c
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,41 @@ struct Mixed4 ReturnMixed4(double a, uint32_t b, float c) {
return s;
}

struct Mixed5{
int64_t *a;
int32_t b;
float c;
int32_t d;
};

struct Mixed5 ReturnMixed5(int64_t *a, int32_t b, float c, int32_t d) {
struct Mixed5 s = {a, b, c, d};
return s;
}

struct SmallBool{
_Bool a;
int32_t b;
int64_t c;
};

struct SmallBool ReturnSmallBool(_Bool a, int32_t b, int64_t c) {
struct SmallBool s = {a, b, c};
return s;
}

struct LargeBool{
_Bool a;
int32_t b;
int64_t c;
int64_t d;
};

struct LargeBool ReturnLargeBool(_Bool a, int32_t b, int64_t c, int64_t d) {
struct LargeBool s = {a, b, c, d};
return s;
}

struct Ptr1{
int64_t *a;
void *b;
Expand Down