Skip to content
This repository was archived by the owner on Feb 2, 2024. It is now read-only.

Commit 08f9c1e

Browse files
author
Ivan Butygin
committed
Function pointer example
1 parent c59fd1f commit 08f9c1e

File tree

3 files changed

+54
-5
lines changed

3 files changed

+54
-5
lines changed

sdc/_concurrent_hash.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,16 @@ void deleteiter_int_hashmap(void* obj)
9191
delete static_cast<int_hashmap_iters*>(obj);
9292
}
9393

94+
using funcptr_t = int32_t(*)(int32_t,int32_t,int32_t);
95+
int32_t test_funcptr(funcptr_t func, int32_t a, int32_t b)
96+
{
97+
int32_t res = 0;
98+
for (int i = 0; i < 10; ++i)
99+
{
100+
res += func(a, b, i);
101+
}
102+
return res;
103+
}
94104

95105
PyMODINIT_FUNC PyInit_hconcurrent_hash()
96106
{
@@ -118,6 +128,8 @@ PyMODINIT_FUNC PyInit_hconcurrent_hash()
118128
REGISTER(iterkey_int_hashmap)
119129
REGISTER(iterval_int_hashmap)
120130
REGISTER(deleteiter_int_hashmap)
131+
132+
REGISTER(test_funcptr)
121133
#undef REGISTER
122134
return m;
123135
}

sdc/concurrent_hash.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828
import sdc
2929

3030
from numba import types, typing, generated_jit
31-
from numba.extending import models, register_model
32-
from numba.extending import lower_builtin, overload_method, overload, intrinsic
31+
from numba.extending import lower_builtin, overload_method, overload, intrinsic, register_jitable
3332

3433
from llvmlite import ir as lir
3534
import llvmlite.binding as ll
@@ -45,6 +44,8 @@
4544
ll.add_symbol('iterval_int_hashmap', hconcurrent_hash.iterval_int_hashmap)
4645
ll.add_symbol('deleteiter_int_hashmap', hconcurrent_hash.deleteiter_int_hashmap)
4746

47+
ll.add_symbol('test_funcptr', hconcurrent_hash.test_funcptr)
48+
4849
_create_int_hashmap = types.ExternalFunction("create_int_hashmap",
4950
types.voidptr())
5051
_delete_int_hashmap = types.ExternalFunction("delete_int_hashmap",
@@ -65,6 +66,9 @@
6566
_deleteiter_int_hashmap = types.ExternalFunction("deleteiter_int_hashmap",
6667
types.void(types.voidptr))
6768

69+
_test_funcptr = types.ExternalFunction("test_funcptr",
70+
types.int32(types.voidptr,types.int32,types.int32))
71+
6872

6973
def create_int_hashmap():
7074
pass
@@ -102,6 +106,10 @@ def deleteiter_int_hashmap():
102106
pass
103107

104108

109+
def test_funcptr():
110+
pass
111+
112+
105113
@overload(create_int_hashmap)
106114
def create_int_hashmap_overload():
107115
return lambda: _create_int_hashmap()
@@ -145,3 +153,18 @@ def iterval_int_hashmap_overload(h):
145153
@overload(deleteiter_int_hashmap)
146154
def deleteiter_int_hashmap_overload(h):
147155
return lambda h: _deleteiter_int_hashmap(h)
156+
157+
158+
159+
@register_jitable
160+
def sink(*args):
161+
args[0]
162+
163+
@overload(test_funcptr)
164+
def test_funcptr_overload(a,b,c):
165+
def func(a,b,c):
166+
res = _test_funcptr(a,b,c)
167+
sink(a,b,c)
168+
return res
169+
170+
return func

sdc/tests/test_dataframe.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -1794,7 +1794,14 @@ def test_impl():
17941794
def test_tbb(self):
17951795
import sdc.concurrent_hash
17961796

1797-
def test_impl():
1797+
@numba.cfunc("int32(int32, int32, int32)")
1798+
def callback(x, y, z):
1799+
return x + y + z
1800+
1801+
global funcptr
1802+
funcptr = callback.address
1803+
1804+
def test_impl1():
17981805
h = sdc.concurrent_hash.create_int_hashmap()
17991806

18001807
sdc.concurrent_hash.addelem_int_hashmap(h, 1, 2)
@@ -1814,8 +1821,15 @@ def test_impl():
18141821
sdc.concurrent_hash.deleteiter_int_hashmap(it)
18151822
sdc.concurrent_hash.delete_int_hashmap(h)
18161823

1817-
hpat_func = self.jit(test_impl)
1818-
hpat_func()
1824+
hpat_func1 = self.jit(test_impl1)
1825+
hpat_func1()
1826+
1827+
def test_impl2():
1828+
r = sdc.concurrent_hash.test_funcptr(funcptr, 2, 3)
1829+
print('res', r)
1830+
1831+
hpat_func2 = self.jit(test_impl2)
1832+
hpat_func2()
18191833

18201834

18211835
if __name__ == "__main__":

0 commit comments

Comments
 (0)