Source code for transonic.backends.numba

"""Numba backend
================

Internal API
------------

.. autoclass:: SubBackendJITNumba
   :members:
   :private-members:

.. autoclass:: NumbaBackend
   :members:
   :private-members:

"""

from typing import Optional

from transonic.analyses.extast import parse, unparse, CommentLine, gast
from transonic.util import format_str

from .py import PythonBackend, SubBackendJITPython


def add_numba_comments(code):
    """Add Numba code in Python comments"""
    mod = parse(code)
    new_body = [CommentLine("# __protected__ from numba import njit")]

    for node in mod.body:
        if isinstance(node, gast.FunctionDef):
            new_body.append(
                CommentLine("# __protected__ @njit(cache=True, fastmath=True)")
            )
        new_body.append(node)

    mod.body = new_body
    return format_str(unparse(mod))


[docs]class SubBackendJITNumba(SubBackendJITPython): def make_backend_source(self, info_analysis, func, path_backend): src, has_to_write = super().make_backend_source( info_analysis, func, path_backend ) if not src: return src, has_to_write return add_numba_comments(src), has_to_write
[docs]class NumbaBackend(PythonBackend): """Main class for the Numba backend""" backend_name = "numba" _SubBackendJIT = SubBackendJITNumba def compile_extension( self, path_backend, name_ext_file=None, native=False, xsimd=False, openmp=False, str_accelerator_flags: Optional[str] = None, parallel=True, force=True, ): if name_ext_file is None: name_ext_file = self.name_ext_from_path_backend(path_backend) with open(path_backend) as file: source = file.read() source = source.replace("# __protected__ ", "") with open(path_backend.with_name(name_ext_file), "w") as file: file.write(format_str(source)) compiling = False process = None return compiling, process
[docs] def _make_backend_code(self, path_py, analysis, **kwargs): """Create a backend code from a Python file""" code, codes_ext, header = super()._make_backend_code(path_py, analysis) if not code: return code, codes_ext, header code = add_numba_comments(code) for_meson = kwargs.get("for_meson", False) if for_meson: code = format_str(code.replace("# __protected__ ", "")) return code, codes_ext, header