# Source code for transonic.backends.jax

"""Jax backend
================

Internal API
------------

.. autoclass:: SubBackendJITJax
   :members:
   :private-members:

.. autoclass:: JaxBackend
   :members:
   :private-members:

"""

from typing import Optional

from transonic.analyses.extast import CommentLine, gast, parse, unparse
from transonic.util import format_str

from .py import PythonBackend, SubBackendJITPython


def _index_after_future_imports(body):
    """Return the index of the first statement of ``body`` that is neither
    the module docstring nor a ``from __future__`` import.

    ``from __future__`` imports must be the first statements of a module
    (only a docstring may precede them), so any code added at the top of a
    module has to be inserted at this index.
    """
    index = 0
    if (
        body
        and isinstance(body[0], gast.Expr)
        and isinstance(body[0].value, gast.Constant)
        and isinstance(body[0].value.value, str)
    ):
        index = 1
    while (
        index < len(body)
        and isinstance(body[index], gast.ImportFrom)
        and body[index].module == "__future__"
    ):
        index += 1
    return index


def add_jax_comments(code):
    """Add Jax code in Python comments

    Lines that should only be active in the Jax version of the module are
    emitted as ``# __protected__ ...`` comments; the backend later removes
    the ``# __protected__ `` prefix to activate them.

    Transformations applied to the module-level statements of ``code``:

    - a protected ``from jax import jit`` line is added at the top of the
      module, after the module docstring and any ``from __future__``
      imports (which must stay first);
    - ``import numpy`` / ``import numpy as np`` become
      ``import jax.numpy as numpy`` / ``import jax.numpy as np``; other
      modules imported in the same statement are kept unchanged;
    - ``from numpy import eye`` becomes ``from jax.numpy import eye``
      (relative imports such as ``from .numpy import x`` are untouched);
    - a protected ``@jit`` line is added before each function definition,
      except ``arguments_blocks``, ``__transonic__`` and functions whose
      name starts with ``__code_new_method__``.

    Parameters
    ----------
    code : str
        Python source code of the backend module.

    Returns
    -------
    str
        The transformed source code, passed through ``format_str``.
    """
    mod = parse(code)
    body = mod.body
    jit_import = CommentLine("# __protected__ from jax import jit")
    index_jit_import = _index_after_future_imports(body)

    new_body = []
    for index, node in enumerate(body):
        if index == index_jit_import:
            new_body.append(jit_import)

        # Replace `import numpy` -> `import jax.numpy as numpy`
        # Replace `import numpy as np` -> `import jax.numpy as np`
        # Every alias is checked so that e.g. `import os, numpy` keeps `os`.
        if isinstance(node, gast.Import):
            node.names = [
                gast.alias(name="jax.numpy", asname=alias.asname or alias.name)
                if alias.name == "numpy"
                else alias
                for alias in node.names
            ]

        # Replace `from numpy import eye` -> `from jax.numpy import eye`
        # (`level` is non-zero for relative imports, which are not numpy).
        elif (
            isinstance(node, gast.ImportFrom)
            and node.module == "numpy"
            and not node.level
        ):
            node.module = "jax.numpy"

        # Add JIT decorator
        if (
            isinstance(node, gast.FunctionDef)
            and node.name
            not in (
                "arguments_blocks",
                "__transonic__",
            )
            and not node.name.startswith("__code_new_method__")
        ):
            new_body.append(CommentLine("# __protected__ @jit"))
        new_body.append(node)

    # Empty module, or a module made only of a docstring / future imports.
    if index_jit_import == len(body):
        new_body.append(jit_import)

    mod.body = new_body
    return format_str(unparse(mod))


class SubBackendJITJax(SubBackendJITPython):
    """JIT sub-backend generating Jax code.

    The backend source is produced by the Python JIT sub-backend and then
    post-processed with ``add_jax_comments``.
    """

    def make_backend_source(self, info_analysis, func, path_backend):
        """Return ``(source, has_to_write)`` for the JIT backend file.

        A non-empty source built by the parent class is transformed with
        ``add_jax_comments``; an empty one is returned as is.
        """
        result = super().make_backend_source(info_analysis, func, path_backend)
        src, has_to_write = result
        if src:
            src = add_jax_comments(src)
        return src, has_to_write
class JaxBackend(PythonBackend):
    """Main class for the Jax backend

    See https://github.com/google/jax
    """

    backend_name = "jax"
    _SubBackendJIT = SubBackendJITJax

    def compile_extension(
        self,
        path_backend,
        name_ext_file=None,
        native=False,
        xsimd=False,
        openmp=False,
        str_accelerator_flags: Optional[str] = None,
        parallel=True,
        force=True,
    ):
        """Write the Jax version of a backend file (no real compilation).

        The ``# __protected__ `` prefixes are stripped from the content of
        ``path_backend``; the result is formatted with ``format_str`` and
        written next to ``path_backend`` under ``name_ext_file``. When
        ``name_ext_file`` is None, it is derived from ``path_backend``.
        The other compilation options are unused here.

        Returns
        -------
        tuple
            ``(False, None)``: nothing is compiling and no process is started.
        """
        if name_ext_file is None:
            name_ext_file = self.name_ext_from_path_backend(path_backend)

        with open(path_backend) as file:
            protected_source = file.read()
        jax_source = protected_source.replace("# __protected__ ", "")

        path_ext = path_backend.with_name(name_ext_file)
        with open(path_ext, "w") as file:
            file.write(format_str(jax_source))

        # No subprocess is launched: report "not compiling" and no process.
        return False, None

    def _make_backend_code(self, path_py, analysis, **kwargs):
        """Create a backend code from a Python file

        The code produced by the Python backend is transformed with
        ``add_jax_comments``. If the keyword argument ``for_meson`` is true,
        the protected lines are activated immediately (prefix removed) and
        the code is formatted again.

        Returns ``(code, codes_ext, header)``; an empty code is returned
        unchanged.
        """
        code, codes_ext, header = super()._make_backend_code(path_py, analysis)
        if code:
            code = add_jax_comments(code)
            if kwargs.get("for_meson", False):
                code = format_str(code.replace("# __protected__ ", ""))
        return code, codes_ext, header