"""Internal utilities
=====================
Public API
----------
.. autofunction:: set_compile_at_import
Internal API
------------
.. autofunction:: find_module_name_from_path
.. autofunction:: get_module_name
.. autofunction:: modification_date
.. autofunction:: has_to_build
.. autofunction:: get_source_without_decorator
.. autoclass:: TypeHintRemover
:members:
:private-members:
.. autofunction:: strip_typehints
.. autofunction:: get_ipython_input
.. autofunction:: get_info_from_ipython
.. autofunction:: has_to_compile_at_import
.. autofunction:: import_from_path
.. autofunction:: query_yes_no
.. autofunction:: clear_cached_extensions
.. autofunction:: is_method
.. autofunction:: has_to_write
.. autofunction:: write_if_has_to_write
"""
import os
import sys
import inspect
import re
from pathlib import Path
import importlib.util
import shutil
from textwrap import dedent
from typing import Callable
import gast as ast
from transonic.config import backend_default
try:
    # black is still in beta (as of 03/2019), so no specific version can be
    # required :-(
    import black
except ImportError:
    # black is not available: fall back on autopep8
    import autopep8

    def format_str(src_contents):
        """Format some source code with autopep8 (fallback without black)"""
        try:
            formatted = autopep8.fix_code(src_contents)
        except AttributeError:
            # workaround https://github.com/hhatto/autopep8/issues/689
            formatted = src_contents
        return formatted


else:
    try:
        _mode = black.FileMode(line_length=82)
    except TypeError:
        # ``black.FileMode`` refused the ``line_length`` keyword: give the
        # line length directly to ``black.format_str``

        def format_str(src_contents: str):
            """Format some source code with black (``line_length`` keyword)"""
            try:
                formatted = black.format_str(src_contents, line_length=82)
            except black.InvalidInput:
                # show the offending code before propagating the error
                print("black.InvalidInput\n" + src_contents)
                raise
            return formatted

    else:

        def format_str(src_contents: str):
            """Format some source code with black (``mode`` keyword)"""
            try:
                formatted = black.format_str(src_contents, mode=_mode)
            except black.InvalidInput:
                # show the offending code before propagating the error
                print("black.InvalidInput\n" + src_contents)
                raise
            return formatted
try:
from IPython.core.getipython import get_ipython
except ImportError:
pass
from transonic import __version__
from transonic.analyses import extast
from transonic.compiler import (
ext_suffix,
make_hex,
modification_date,
has_to_build,
)
from transonic.config import path_root, strtobool
# Names exported by ``from <this module> import *``: ``modification_date`` and
# ``has_to_build`` are imported above from ``transonic.compiler`` and
# ``path_root`` from ``transonic.config``.
__all__ = ["modification_date", "has_to_build", "path_root"]
def can_import_accelerator(backend: str = backend_default):
    """Tell whether the package needed by a backend can be imported

    Returns ``True`` for the "python" backend (nothing to import) and for
    "pythran", "cython" and "numba" when the module of the same name is
    importable, ``False`` when this import fails.

    Raises ``NotImplementedError`` for any other backend name.
    """
    if backend == "python":
        return True

    if backend not in ("pythran", "cython", "numba"):
        raise NotImplementedError

    # the module to import has the same name as the backend
    try:
        importlib.import_module(backend)
    except ImportError:
        return False
    return True
def print_versions(accelerators=None):
    """Print the version of Transonic and of the accelerators

    ``accelerators`` is a container of names among "pythran", "numba" and
    "cython". With ``None`` (default), the three accelerators are reported.
    An accelerator that cannot be imported is reported as "not importable".
    """
    print(f"Transonic {__version__}")

    # (name looked for in ``accelerators``, module to import, printed label)
    known_accelerators = (
        ("pythran", "pythran", "Pythran"),
        ("numba", "numba", "Numba"),
        ("cython", "Cython", "Cython"),
    )

    for key, name_module, label in known_accelerators:
        if accelerators is not None and key not in accelerators:
            continue
        try:
            module = importlib.import_module(name_module)
        except ImportError:
            print(f"{label} not importable")
        else:
            print(f"{label} {module.__version__}")
def find_module_name_from_path(path_py: Path):
    """Find the module name from the path of a Python file

    The directories above the file are walked upwards until one of them is
    the current working directory or is listed in ``sys.path``; the names of
    the directories crossed on the way form the package part of the dotted
    name. If no such directory is found, the bare stem of the file is
    returned.

    Files located below a ``__jit_class__`` directory are named from this
    directory: ``__jit_class__.<sub-directories>.<stem>``.
    """
    path_py = Path(path_py)
    stem = path_py.stem
    directory = path_py.absolute().parent

    # special case for jit_classes
    special_dir = "__jit_class__"
    parts = directory.parts
    if special_dir in parts:
        below = parts[parts.index(special_dir) + 1 :]
        return ".".join([special_dir, *below, stem])

    cwd = Path.cwd()
    dotted_name = stem
    # the directory of the file followed by its ancestors; the last element
    # (the filesystem root, which has no parent) is never considered
    candidates = [directory, *directory.parents][:-1]
    for candidate in candidates:
        if candidate == cwd or str(candidate) in sys.path:
            return dotted_name
        dotted_name = f"{candidate.name}.{dotted_name}"

    return stem
def _get_pathfile_from_frame(frame):
    """Return the filename stored in the code object of a frame"""
    code = frame.f_code
    return code.co_filename
def get_module_name(frame):
    """Get the full module name of the code executed in a frame

    The name of the module found by ``inspect.getmodule`` is used, except
    when there is no module or when it is called ``__main__`` or
    ``<run_path>``: the name is then deduced from the path of the file
    (see ``find_module_name_from_path``).
    """
    try:
        module = inspect.getmodule(frame)
    except AttributeError:
        # bug when runpy.run_path is called with pathlib.Path
        return _get_pathfile_from_frame(frame)

    if module is None:
        pathfile = _get_pathfile_from_frame(frame)
    else:
        module_name = module.__name__
        if module_name in ("__main__", "<run_path>"):
            pathfile = _get_pathfile_from_frame(frame)
        else:
            pathfile = None

    if pathfile is not None:
        module_name = find_module_name_from_path(Path(pathfile))

    if module_name is None:
        # ipython ?
        _, module_name = get_info_from_ipython()

    return module_name
def get_name_calling_module():
    """Get the module name corresponding to the frame of the caller"""
    # depth 1: the frame from which this function has been called
    frame_caller = get_frame(1)
    return get_module_name(frame_caller)
def get_frame(depth=1):
    """Return a frame object from the call stack.

    Equivalent to ``inspect.currentframe().f_back`` followed by ``depth``
    more ``.f_back``: ``depth = 0`` is the frame calling this function and
    ``depth = 1`` is the caller of this frame.

    When the stack is not deep enough (``ValueError``), the filenames of the
    stack are printed before the exception is propagated.
    """
    try:
        # + 1 to skip the frame of ``get_frame`` itself
        frame = sys._getframe(depth + 1)
    except AttributeError:
        # we might want to implement this with another function
        # (``inspect.currentframe`` or ``inspect.stack()``)
        # for alternative Python implementations
        raise
    except ValueError:
        filenames = [frame_info.filename for frame_info in inspect.stack()]
        print(filenames)
        raise
    return frame
def get_source_without_decorator(func: Callable):
    """Get the source of a function without its decorator

    The source is dedented, the text going from a ``@`` to the following
    ``def`` is replaced by ``def`` and the type hints are stripped.
    """
    source = dedent(inspect.getsource(func))
    undecorated = re.sub(r"@.*?\sdef\s", "def ", source)
    return strip_typehints(undecorated)
[docs]class TypeHintRemover(ast.NodeTransformer):
"""Strip the type hints
from https://stackoverflow.com/a/42734810/1779806
"""
def visit_FunctionDef(self, fdef):
# remove the return type defintion
fdef.returns = None
# remove all argument annotations
if fdef.args.args:
for arg in fdef.args.args:
arg.annotation = None
body = []
for node in fdef.body:
if isinstance(node, ast.AnnAssign):
if node.value is None:
continue
node = ast.Assign(
targets=[node.target], value=node.value, type_comment=None
)
body.append(node)
fdef.body = body
return fdef
def strip_typehints(source):
    """Strip the type hints from a function

    The code is formatted (``format_str``), parsed, transformed with
    ``TypeHintRemover`` and finally converted back to source code.
    """
    # parse the formatted source code into an AST
    tree = ast.parse(format_str(source))
    # remove the type annotations and the return type definitions
    tree_without_hints = TypeHintRemover().visit(tree)
    # convert the AST back to source code
    return extast.unparse(tree_without_hints)
def make_code_from_fdef_node(fdef):
    """Produce formatted code without type hints from a function definition node"""
    node_without_hints = TypeHintRemover().visit(fdef)
    # convert the AST back to source code and format it
    return format_str(extast.unparse(node_without_hints))
def get_info_from_ipython():
    """Get the input code and a "filename" when called from IPython

    The dummy filename is ``"__ipython__"`` followed by the result of
    ``make_hex`` applied to the input code.
    """
    src = get_ipython_input()
    return src, "__ipython__" + make_hex(src)
# ``None`` means "not set programmatically": the environment decides
_PYTHRANIZE_AT_IMPORT = None


def set_compile_at_import(value=True):
    """Control the "compile_at_import" mode

    The given value takes precedence over the environment variable
    ``TRANSONIC_COMPILE_AT_IMPORT`` (unless it is ``None``).
    """
    global _PYTHRANIZE_AT_IMPORT
    _PYTHRANIZE_AT_IMPORT = value


def has_to_compile_at_import():
    """Check if transonic has to pythranize at import time

    Returns the value given to ``set_compile_at_import`` if there is one,
    otherwise whether ``TRANSONIC_COMPILE_AT_IMPORT`` is defined in the
    environment (whatever its value).
    """
    if _PYTHRANIZE_AT_IMPORT is None:
        return "TRANSONIC_COMPILE_AT_IMPORT" in os.environ
    return _PYTHRANIZE_AT_IMPORT
def import_from_path(path: Path, module_name: str):
    """Import a .py file or an extension from its path

    Parameters
    ----------
    path : Path (or str)
        Path of the file to import.
    module_name : str
        Name under which the module is registered in ``sys.modules``. For a
        dotted name, the last component is replaced by the name of the file
        (up to its first dot) if they differ. For a name without dot, the
        stem of the path is used.

    Returns
    -------
    The imported module. A module already in ``sys.modules`` under this name
    is reused only if its file ends with ``ext_suffix`` and is ``path``.

    Raises
    ------
    ImportError
        If the file does not exist or if no loader can be found for it.
    """
    path = Path(path)

    if not path.exists():
        raise ImportError(
            f"File {path} does not exist. "
            f"[path.name for path in path.parent.glob('*')]:\n{[path.name for path in path.parent.glob('*')]}\n"
        )

    if "." in module_name:
        package_name, mod_name = module_name.rsplit(".", 1)
        name_file = path.name.split(".", 1)[0]
        if mod_name != name_file:
            module_name = ".".join((package_name, name_file))
    else:
        module_name = path.stem

    if module_name in sys.modules:
        module = sys.modules[module_name]
        # ``__file__`` can be missing or None (for example namespace packages)
        module_file = getattr(module, "__file__", None)
        if (
            module_file is not None
            and module_file.endswith(ext_suffix)
            and Path(module_file) == path
        ):
            return module

    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        # happens when the suffix of the file is not importable
        raise ImportError(f"Cannot find a loader to import the file {path}")

    # for potential "local imports" in the module
    dir_module = str(path.parent)
    sys.path.insert(0, dir_module)
    try:
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    finally:
        # clean sys.path, even if the import failed. The entry is removed by
        # value because the imported code may itself have modified sys.path,
        # in which case the first element is no longer the one added here.
        try:
            sys.path.remove(dir_module)
        except ValueError:
            pass

    # fix bug Numba
    sys.modules[module_name] = module
    return module
def query_yes_no(question: str, default: str = None, force: bool = False):
    """User yes or no query

    Parameters
    ----------
    question : str
        Text printed before the possible answers.
    default : {None, "y", "n"}
        Answer used when the user just presses Enter. With ``None``, an
        explicit answer is required.
    force : bool
        If true, nothing is asked and ``True`` is returned.

    Returns
    -------
    ``True`` if ``force``, otherwise the result of ``strtobool`` applied to
    the answer. The question is asked again as long as ``strtobool`` raises
    ``ValueError``.

    Raises
    ------
    ValueError
        If ``default`` is not ``None``, "y" or "n".
    """
    if force:
        return True

    if default is None:
        end = "(y/n)"
        # an empty answer is then rejected by strtobool and asked again
        default = ""
    elif default == "y":
        end = "([y]/n)"
    elif default == "n":
        end = "(y/[n])"
    else:
        # fail clearly instead of hitting an unbound local variable below
        raise ValueError(
            f"Invalid default answer {default!r} (has to be None, 'y' or 'n')"
        )

    print(f"{question} {end}")

    while True:
        answer = input()
        if answer == "":
            answer = default
        try:
            return strtobool(answer)
        except ValueError:
            print('Please respond with "y" or "n".')
def clear_cached_extensions(module_name: str, force: bool, backend: str):
    """Delete the cached extensions related to a module

    Parameters
    ----------
    module_name : str
        Dotted name of the module or path of the file (a trailing ".py" is
        dropped).
    force : bool
        If true, the files are removed without asking for confirmation.
    backend : str
        Key of the backend in ``transonic.backends.backends``.

    Two kinds of cached files are considered: the directory of the module
    below ``backend.jit.path_base`` and, next to the module, the file
    ``__<backend name>__/<module>.py`` together with the extension whose name
    is given by ``backend.name_ext_from_path_backend``.
    """
    from transonic.backends import backends
    from transonic import mpi

    backend = backends[backend]
    path_jit = mpi.Path(backend.jit.path_base)

    if module_name.endswith(".py"):
        module_name = module_name[:-3]

    if os.path.sep not in module_name:
        # dotted module name -> relative path
        relative_path = module_name.replace(".", os.path.sep)
    else:
        relative_path = module_name

    path_pythran_dir_jit = path_jit / relative_path

    relative_path = Path(relative_path)
    # f-string needed here: without it, the directory was looked for under
    # the literal name "__{backend.name}__" and was never found
    path_pythran = relative_path.parent / (
        f"__{backend.name}__/" + relative_path.name + ".py"
    )
    path_ext = path_pythran.with_name(
        backend.name_ext_from_path_backend(path_pythran)
    )

    if not path_pythran_dir_jit.exists() and not (
        path_pythran.exists() or path_ext.exists()
    ):
        print(
            f"Not able to find cached extensions corresponding to {module_name}"
        )
        print("Nothing to do! ✨ 🍰 ✨.")
        return

    if path_pythran_dir_jit.exists() and query_yes_no(
        f"Do you confirm that you want to delete the cached files for {module_name}",
        default="y",
        force=force,
    ):
        print(f"Remove directory {path_pythran_dir_jit}")
        shutil.rmtree(path_pythran_dir_jit)

    if path_pythran.exists() or path_ext.exists():
        if query_yes_no(
            f"Do you confirm that you want to delete the AOT cache for {module_name}",
            default="y",
            force=force,
        ):
            for path in (path_pythran, path_ext):
                if path.exists():
                    path.unlink()
def is_method(func):
    """Determine whether a function is going to be used as a method

    The answer is true when the first parameter of ``func`` is named "self".
    """
    parameters = inspect.signature(func).parameters
    # None stands for "no parameter" (a parameter name is always a string)
    name_first_parameter = next(iter(parameters), None)
    return name_first_parameter == "self"
def has_to_write(path_file: Path, new_code: str):
    """Check if a file has to be (re)written to contain a code

    Returns ``False`` only if the file exists and its content is equal to
    ``new_code``.
    """
    if not path_file.exists():
        return True

    with open(path_file, "r") as file:
        current_code = file.read()

    return not (new_code == current_code)
def write_if_has_to_write(
    path_file: Path, new_code: str, logger=None, force=False
):
    """Write a file if it doesn't exist or doesn't contain a particular code

    The file is also written if ``force`` is true. The parent directories
    are created if needed and ``logger`` (if given) is called with a message
    once the file is written.

    Returns ``True`` if the file has been written, else ``False``.
    """
    if not (has_to_write(path_file, new_code) or force):
        return False

    path_file.parent.mkdir(exist_ok=True, parents=True)
    with open(path_file, "w") as file:
        file.write(new_code)

    if logger:
        logger(f"{path_file} written")

    return True
def timeit(stmt="pass", setup="pass", total_duration=2, globals=None):
    """timeit with an approximate total_duration

    The duration of one execution is first estimated with 10 executions.
    The numbers of repetitions and of executions per repetition are then
    chosen so that the whole measurement lasts approximately
    ``total_duration`` (in seconds).

    Returns the best time per execution.

    Raises ``RuntimeError`` if the planned executions would take more than
    twice ``total_duration``.
    """
    from timeit import Timer

    timer = Timer(stmt, setup, globals=globals)

    # the result of this first run is discarded (presumably a warm-up)
    timer.timeit(10)
    duration_one_call = timer.timeit(10) / 10

    number = max(1, int(round(total_duration / duration_one_call)))
    repeat = 1
    if number > 200:
        # split into repetitions of about 100 executions
        repeat = int(round(number / 100))
        number = int(round(number / repeat))

    if number * duration_one_call > 2 * total_duration:
        raise RuntimeError("number * duration1 > 2 * total_duration")

    best_duration = min(timer.repeat(repeat=repeat, number=number))
    return best_duration / number
def timeit_verbose(
    stmt,
    setup="pass",
    total_duration=2,
    globals=None,
    norm=None,
    name=None,
    print_time=False,
    max_length_name=33,
):
    """Call ``timeit`` and print the result relative to a norm

    If ``norm`` is not given, the measured time is used as the norm (the
    printed ratio is then 1) and the norm is also printed, unless
    ``print_time`` is true. If ``name`` is not given, the part of ``stmt``
    before the first "(" is used. The name is padded to ``max_length_name``
    characters.

    Returns the time measured by ``timeit``.
    """
    result = timeit(
        stmt, setup=setup, total_duration=total_duration, globals=globals
    )

    norm_given = norm is not None
    if not norm_given:
        norm = result

    if name is None:
        name = stmt.split("(")[0]
    name = f"{{:{max_length_name}s}}".format(name)

    raw_time = f" = {result:7.3g} s" if print_time else ""

    ratio = result / norm
    print(f"{name}: {ratio:5.3g} * norm{raw_time}")

    if not (norm_given or print_time):
        print(f"norm = {norm:5.3g} s")

    return result