numba extension support (clifford.numba)
New in version 1.4.0.
This module provides numba extension types MultiVectorType and LayoutType, corresponding to MultiVector and Layout.
You do not need to import this module to take advantage of these types; they are only needed directly when writing numba overloads via numba.extending.overload() and similar APIs.
As a simple example, the following code defines a jitted version of the up() function for CGA:
```python
from clifford.g3c import *
import numba

@numba.njit
def jit_up(x):
    # conformal embedding of a vector: eo + x + 1/2 |x|^2 einf
    return eo + x + 0.5*abs(x)**2*einf

assert up(e1) == jit_up(e1)
```
Note that a rough equivalent to this particular function is provided elsewhere as clifford.tools.g3c.fast_up().
Supported operations
The following operations are supported in a jitted context:
- A limited version of the constructor MultiVector(layout, value), and the alias layout.MultiVector()
- MultiVector.value
- MultiVector.layout
- Arithmetic:
  - MultiVector.__pow__()
  - MultiVector.__truediv__()
  - MultiVector.__pos__()
  - MultiVector.__neg__()
  - MultiVector.__abs__()
Performance considerations
While the resulting jitted code is typically faster, there are two main performance issues to consider. The first is the startup time of @jit-ing: numba compiles a function the first time it is called with a new combination of argument types. This can be substantial, although it can be somewhat alleviated by passing the cache=True argument to numba.jit().
The second is the time taken for numba to select the appropriate compiled implementation given the Python types of its arguments, which adds overhead to every call. clifford tries as hard as possible to reduce this second overhead by using the undocumented _numba_type_ attribute and keeping our own optimized cache, instead of going through the recommended @numba.extending.typeof_impl.register(LayoutType) mechanism.
However, this overhead can still be large compared to the cost of elementary operations such as a single multivector product.
The following code is significantly impacted by this overhead:
```python
from clifford.g3c import *
import numba

@numba.njit
def mul(a, b):
    return a * b

# 286 ms, ignoring jitting time
x = e1
for i in range(10000):
    # every call from Python pays the dispatch overhead again
    x = mul(x, x + e2)
```
Each iteration of the loop pays this overhead again. It can be avoided by jitting the entire loop:
```python
from clifford.g3c import *
import numba

@numba.njit
def mul(a, b):
    return a * b

@numba.njit
def the_loop(x):
    # the whole loop runs in compiled code, so argument dispatch happens once
    for i in range(1000):
        x = mul(x, x + e1)
    return x

# 2.4 ms, ignoring jitting time
x = the_loop(e1)
```