"""Create Pythran signatures from type hints
============================================

User API
--------

.. autoclass:: Type
   :members:

.. autoclass:: NDim
   :members:

.. autoclass:: Array
   :members:
   :private-members:

.. autoclass:: List
   :members:
   :private-members:

.. autoclass:: Tuple
   :members:
   :private-members:

.. autoclass:: Dict
   :members:
   :private-members:

.. autoclass:: Set
   :members:
   :private-members:

.. autoclass:: Union
   :members:
   :private-members:

.. autofunction:: str2type

.. autofunction:: typeof

.. autofunction:: const

Internal API
------------

.. autoclass:: TemplateVar
   :members:
   :private-members:

.. autoclass:: ArrayMeta
   :members:
   :private-members:

.. autoclass:: ListMeta
   :members:
   :private-members:

.. autoclass:: DictMeta
   :members:
   :private-members:

.. autofunction:: format_type_as_backend_type

.. autoclass:: ConstType
   :members:
   :private-members:

"""
import re
from enum import Enum, auto
import itertools
import numpy as np
from transonic.util import get_name_calling_module
# Registry of the names already used by template variables, filled by
# ``TemplateVar.__init__``:
#   {name of the calling module: {TemplateVar subclass: set of names}}
# It is used to generate automatic names (T0, T1, N0, ...) that do not
# collide within one module.
names_template_variables: dict = {}
class FusedType:
    """Mixin for Transonic types that may stand for several concrete types.

    Subclasses are expected to provide ``get_template_parameters`` and to
    override ``is_fused_type``.
    """

    def is_fused_type(self):
        raise NotImplementedError

    def get_all_formatted_backend_types(self, type_formatter):
        """Format every concrete instantiation of this (fused) type.

        One formatted type is produced for each element of the Cartesian
        product of the values of the template parameters. Parameters sharing
        the same ``__name__`` (e.g. ``N`` and ``N + 1``) are merged, so they
        always take the same value in a given instantiation.
        """
        choices_by_name = {}
        for param in self.get_template_parameters():
            # later parameters with the same name overwrite earlier ones,
            # keeping the position of the first occurrence (dict semantics)
            choices_by_name[param.__name__] = param.values

        param_names = tuple(choices_by_name)
        return [
            format_type_as_backend_type(
                self, type_formatter, **dict(zip(param_names, combination))
            )
            for combination in itertools.product(*choices_by_name.values())
        ]
class TemplateVar:
    """Base class for template variables

    >>> T = TemplateVar("T")
    >>> T = TemplateVar("T", int, float)
    >>> T = TemplateVar()
    Traceback (most recent call last):
    ...
    ValueError
    >>> T = TemplateVar(1)
    Traceback (most recent call last):
    ...
    TypeError: (1,) [False]

    The first positional argument is used as the name when
    ``_is_correct_for_name`` accepts it (a ``str`` by default). Otherwise a
    name ``<_letter><index>`` is generated that is not yet used in the calling
    module for this subclass. The remaining arguments are the possible values
    and are checked with ``isinstance(value, self._type_values)``.
    """

    # Type (or tuple of types) the values must be instances of.
    _type_values = type
    # Prefix of automatically generated names.
    _letter = "T"

    def get_template_parameters(self):
        return (self,)

    def __init__(self, *args, name_calling_module=None):
        if not args:
            raise ValueError

        if name_calling_module is None:
            name_calling_module = get_name_calling_module()

        names_variables = names_template_variables.setdefault(
            name_calling_module, {}
        )
        names_already_used = names_variables.setdefault(type(self), set())

        if self._is_correct_for_name(args[0]):
            name = args[0]
            values = args[1:]
        else:
            index_var = len(names_already_used)
            while self._letter + str(index_var) in names_already_used:
                index_var += 1
            name = self._letter + str(index_var)
            values = args

        self.__name__ = name
        self.values = values
        # Validate before registering, so that a rejected template variable
        # does not consume a name in the registry.
        self._check_type_values()
        names_already_used.add(name)

    def _is_correct_for_name(self, arg):
        return isinstance(arg, str)

    def _check_type_values(self):
        """Raise TypeError if a value is not an instance of ``_type_values``."""
        checks = [isinstance(value, self._type_values) for value in self.values]
        if not all(checks):
            raise TypeError(f"{self.values} {checks}")

    def has_multiple_values(self):
        return len(self.values) > 1
class Type(TemplateVar, FusedType):
    """Template variable representing the dtype of an array.

    As a user, it is useful only for fused types.

    >>> Type(int, float)
    Type(int, float)
    """

    def __repr__(self):
        repr_values = [
            value.__name__ if hasattr(value, "__name__") else repr(value)
            for value in self.values
        ]
        return f"Type({', '.join(repr_values)})"

    def format_as_backend_type(self, backend_type_formatter, **kwargs):
        """Return the name of the type chosen for this variable in ``kwargs``.

        Raises
        ------
        ValueError
            If ``kwargs`` has no (non-None) value for this template variable.
        """
        dtype = kwargs.get(self.__name__)
        if dtype is None:
            raise ValueError(
                f"No value given for the template variable {self.__name__!r} "
                f"(got values for {sorted(kwargs)})"
            )
        return dtype.__name__

    def is_fused_type(self):
        return len(self.values) > 1

    def short_repr(self):
        """Identifier-friendly repr, e.g. ``TypeIint_floatI``."""
        long_repr = repr(self)
        replaced_by = {"(": "I", ")": "I", ", ": "_"}
        for replaced, replacer in replaced_by.items():
            long_repr = long_repr.replace(replaced, replacer)
        return long_repr
class NDim(TemplateVar):
    """Template variable representing the number of dimension of an array.

    As a user, it is useful only for fused types.

    >>> N = NDim(1, 2)
    >>> N1 = N + 1
    >>> (N + 1 + 1).shift
    2

    ``N + k`` and ``N - k`` return a variable with the same name and values
    as ``N`` and an accumulated ``shift``.
    """

    _type_values = int
    _letter = "N"

    def __init__(self, *args, shift=0, name_calling_module=None):
        if name_calling_module is None:
            name_calling_module = get_name_calling_module()
        super().__init__(*args, name_calling_module=name_calling_module)
        # Offset added to the number of dimensions (see __add__/__sub__).
        self.shift = shift

    def __repr__(self):
        if len(self.values) == 1:
            name = f'"{self.values[0]}d"'
        else:
            name = f"NDim({', '.join(repr(v) for v in self.values)})"

        if self.shift == 0:
            return name
        elif self.shift < 0:
            return name + f" - {-self.shift}"
        elif self.shift > 0:
            return name + f" + {self.shift}"
        else:
            raise RuntimeError

    def __add__(self, number):
        name_calling_module = get_name_calling_module()
        # Accumulate the shift so that (N + 1) + 1 is N + 2.
        return type(self)(
            self.__name__,
            *self.values,
            shift=self.shift + number,
            name_calling_module=name_calling_module,
        )

    def __sub__(self, number):
        name_calling_module = get_name_calling_module()
        # Accumulate the shift so that (N + 1) - 1 is N.
        return type(self)(
            self.__name__,
            *self.values,
            shift=self.shift - number,
            name_calling_module=name_calling_module,
        )

    def short_repr(self):
        """Identifier-friendly repr, e.g. ``NDimI2_3Ip1`` for NDim(2, 3) + 1."""
        long_repr = repr(self)
        replaced_by = {
            '"': "",
            "(": "I",
            ")": "I",
            " - ": "m",
            " + ": "p",
            ", ": "_",
        }
        for replaced, replacer in replaced_by.items():
            long_repr = long_repr.replace(replaced, replacer)
        return long_repr
class UnionVar(TemplateVar):
    """TemplateVar used for the Union type

    Its values are the member types of one ``Union[...]``. Automatically
    generated names use the prefix ``U``.
    """

    _letter = "U"
    # A Union member may be a type or ``None`` (as in ``Optional[...]``).
    _type_values = (type, type(None))
class Meta(type, FusedType):
    """Type of the Transonic types (used to create metaclasses)"""

    def __call__(cls, *args, **kwargs):
        raise RuntimeError("Transonic types are not meant to be instantiated")

    def is_fused_type(self):
        """True if at least one template parameter can take several values."""

        def _param_is_fused(param):
            if hasattr(param, "is_fused_type") and param.is_fused_type():
                return True
            return hasattr(param, "has_multiple_values") and bool(
                param.has_multiple_values()
            )

        return any(
            _param_is_fused(param) for param in self.get_template_parameters()
        )
class MemLayout(Enum):
    """Memory layouts: ``C``, ``F``, ``C_or_F`` and ``strided``.

    The repr is the member name between double quotes, as it appears in
    type expressions such as ``Array[int, "2d", "C"]``.
    """

    C = auto()
    F = auto()
    C_or_F = auto()
    strided = auto()

    def __repr__(self):
        return '"%s"' % self.name
def str2shape(str_shape):
    """Convert a shape string such as ``"[:, 3]"`` into a tuple.

    ``:`` becomes ``None`` and integers are kept; empty entries are skipped.
    ``"[]"`` gives ``(None,)`` and a string containing n > 1 closing brackets
    (``"[][]"``) gives a tuple of n ``None``.
    """
    assert str_shape.startswith("[") and str_shape.endswith("]")
    compact = str_shape.replace(" ", "")

    if compact == "[]":
        return (None,)

    nb_closing = compact.count("]")
    if nb_closing > 1:
        return nb_closing * (None,)

    tokens = compact[1:-1].split(",")
    return tuple(None if token == ":" else int(token) for token in tokens if token)
def shape2str(shape):
    """Format a shape tuple (``None`` for ``:``) as a double-quoted string.

    For example ``(None, 3)`` gives ``'"[:,3]"'``.
    """
    inner = ",".join(map(lambda value: ":" if value is None else str(value), shape))
    return '"[' + inner + ']"'
class Array(metaclass=ArrayMeta):
    """Represent a Numpy array.

    >>> Array[int, "2d"]
    Array[int, "2d"]
    >>> Array[int, "2d", "C"]
    Array[int, "2d", "C"]
    >>> Array[int, "2d", "F"]
    Array[int, "2d", "F"]
    >>> Array[int, "2d", "strided"]
    Array[int, "2d", "strided"]

    Fused types:

    >>> Array[Type(int, float), "1d"]
    Array[Type(int, float), "1d"]
    >>> Array[float, NDim(2, 3)]
    Array[float, NDim(2, 3)]
    >>> Array[int, "1d", "C", "positive_indices"]
    Array[int, "1d", "C", "positive_indices"]
    """
class UnionMeta(Meta):
    """Metaclass for the Union class

    ``Union[A, B, ...]`` builds a new subclass of :class:`Union` holding the
    member types (``types``, strings converted with ``str2type``) and a
    ``UnionVar`` template variable (``template_var``) whose values are these
    member types.
    """

    def __getitem__(self, types):
        raw_members = types if isinstance(types, tuple) else (types,)
        types = tuple(
            str2type(member) if isinstance(member, str) else member
            for member in raw_members
        )

        name_calling_module = get_name_calling_module()
        template_var = UnionVar(*types, name_calling_module=name_calling_module)

        def _short(member):
            if hasattr(member, "short_repr"):
                return member.short_repr()
            if hasattr(member, "__name__"):
                return member.__name__
            return repr(member)

        class_name = "Union" + "_".join(_short(member) for member in types)
        namespace = {"types": types, "template_var": template_var}
        return type(class_name, (Union,), namespace)

    def get_template_parameters(self):
        """Parameters of the members followed by the union's own variable."""
        collected = [
            param
            for member in self.types
            if hasattr(member, "get_template_parameters")
            for param in member.get_template_parameters()
        ]
        collected.append(self.template_var)
        return tuple(collected)

    def __repr__(self):
        if not hasattr(self, "types"):
            return super().__repr__()

        def _fmt(member):
            # Transonic types (Meta instances) have a meaningful repr; other
            # classes are shown by name; anything else (e.g. None) by repr.
            if isinstance(member, type) and not isinstance(member, Meta):
                return member.__name__
            return repr(member)

        return "Union[" + ", ".join(map(_fmt, self.types)) + "]"

    def format_as_backend_type(self, backend_type_formatter, **kwargs):
        remaining = dict(kwargs)
        chosen = remaining.pop(self.template_var.__name__)
        return format_type_as_backend_type(
            chosen, backend_type_formatter, **remaining
        )

    def short_repr(self):
        name = self.__name__
        return name
class Union(metaclass=UnionMeta):
    """Similar to typing.Union

    >>> Union[float, Array[int, "1d"]]
    Union[float, Array[int, "1d"]]
    """
class List(metaclass=ListMeta):
    """Similar to typing.List

    >>> List[List[int]]
    List[List[int]]
    """
class Dict(metaclass=DictMeta):
    """Similar to typing.Dict

    >>> Dict[str, int]
    Dict[str, int]
    """
class SetMeta(Meta):
    """Metaclass for the Set class

    ``Set[key]`` builds a new subclass of :class:`Set` whose ``type_keys``
    attribute is the key type (a string is converted with ``str2type``).
    """

    def __getitem__(self, type_keys):
        if isinstance(type_keys, str):
            type_keys = str2type(type_keys)
        return type("SetBis", (Set,), {"type_keys": type_keys})

    def get_template_parameters(self):
        if hasattr(self.type_keys, "get_template_parameters"):
            return self.type_keys.get_template_parameters()
        else:
            return tuple()

    def __repr__(self):
        if not hasattr(self, "type_keys"):
            return super().__repr__()
        keys = self.type_keys
        # Transonic types (Meta instances) must use their repr, as in
        # UnionMeta/TupleMeta, otherwise e.g. Set[Tuple[int]] would be shown
        # with the internal class name "TupleBis".
        if isinstance(keys, type) and not isinstance(keys, Meta):
            key = keys.__name__
        else:
            key = repr(keys)
        return f"Set[{key}]"

    def format_as_backend_type(self, backend_type_formatter, **kwargs):
        return backend_type_formatter.make_set_code(self.type_keys, **kwargs)
class Set(metaclass=SetMeta):
    """Similar to typing.Set

    >>> Set[str]
    Set[str]
    """
class TupleMeta(Meta):
    """Metaclass for the Tuple class

    ``Tuple[A, B, ...]`` builds a new subclass of :class:`Tuple` whose
    ``types`` attribute is the list of element types (strings converted with
    ``str2type``).
    """

    def __getitem__(self, types):
        if not isinstance(types, tuple):
            types = (types,)
        # Bug fix: the string used to be *called* (``type_in(str2type(...))``)
        # instead of being replaced by its Transonic type.
        trans_types = [
            str2type(type_in) if isinstance(type_in, str) else type_in
            for type_in in types
        ]
        return type("TupleBis", (Tuple,), {"types": trans_types})

    def get_template_parameters(self):
        template_params = []
        for type_ in self.types:
            if hasattr(type_, "get_template_parameters"):
                template_params.extend(type_.get_template_parameters())
        return tuple(template_params)

    def __repr__(self):
        if not hasattr(self, "types"):
            return super().__repr__()
        strings = []
        for type_ in self.types:
            if isinstance(type_, Meta):
                name = repr(type_)
            elif isinstance(type_, type):
                name = type_.__name__
            else:
                name = repr(type_)
            strings.append(name)
        return f"Tuple[{', '.join(strings)}]"

    def format_as_backend_type(self, backend_type_formatter, **kwargs):
        return backend_type_formatter.make_tuple_code(self.types, **kwargs)
class Tuple(metaclass=TupleMeta):
    """Similar to typing.Tuple

    >>> Tuple[int, Array[int, "2d"]]
    Tuple[int, Array[int, "2d"]]
    """
class OptionalMeta(Meta):
    """Metaclass for the Optional class: ``Optional[X]`` is ``Union[X, None]``."""

    def __getitem__(self, type_):
        members = (type_, None)
        return Union[members]
class Optional(metaclass=OptionalMeta):
    """Similar to typing.Optional

    ``Optional[X]`` returns ``Union[X, None]``.

    >>> Optional[int]
    Union[int, None]
    """
def str2type(str_type):
    """Compute a Transonic type from a string

    >>> str2type("int[:,:]")
    Array[int, "2d"]
    >>> str2type("int or float[]")
    Union[int, Array[float, "1d"]]
    >>> str2type("(int, float[:, :])")
    Tuple[int, Array[float, "2d"]]
    >>> str2type("(int, float)")
    Tuple[int, float]

    Handled forms: expressions evaluated in this module's namespace
    (e.g. ``"int"``), NumPy scalar names (``"float32"``), unions
    (``"int or float"``), tuples (``"(int, float[:])"``), ``"int list"``,
    ``"str: int dict"``, ``"int set"`` and arrays (``"uint8[:, :]"``).

    Raises
    ------
    ValueError
        If the string is empty or no Transonic type can be determined.

    .. warning::
        The string is passed to :func:`eval`; only use this function with
        trusted strings (e.g. annotations written in the user's own code).
    """
    str_type = str_type.strip()
    if not str_type:
        raise ValueError("Can't determine the Transonic type from an empty string")

    if " or " in str_type:
        subtypes = str_type.split(" or ")
        return Union[tuple(str2type(subtype) for subtype in subtypes)]

    # SECURITY: eval runs arbitrary code; never call this with external data.
    try:
        evaluated = eval(str_type)
    except (TypeError, SyntaxError, NameError):
        # not a simple type
        pass
    else:
        # "(int, float)" evaluates to a plain Python tuple of types, which
        # must become a Transonic Tuple (as "(int, float[:])" does below).
        if isinstance(evaluated, tuple):
            return Tuple[evaluated]
        return evaluated

    if "[" not in str_type:
        # could be a numpy type
        dtype = str_type if str_type.startswith("np.") else "np." + str_type
        try:
            return eval(dtype, {"np": np})
        except (TypeError, SyntaxError, AttributeError):
            pass

    if str_type.startswith("(") and str_type.endswith(")"):
        # split on commas that are not inside [...] or (...)
        re_comma = re.compile(r",(?![^\[]*\])(?![^\(]*\))")
        return Tuple[
            tuple(
                str2type(word) for word in re_comma.split(str_type[1:-1]) if word
            )
        ]

    words = [word for word in str_type.split(" ") if word]
    if words[-1] == "list":
        return List[" ".join(words[:-1])]

    if words[-1] == "dict":
        if len(words) != 3:
            raise NotImplementedError(f"words: {words}")
        key = words[0][:-1]
        value = words[1]
        return Dict[key, value]

    if words[-1] == "set":
        if len(words) != 2:
            raise NotImplementedError(f"words: {words}")
        key = words[0]
        return Set[key]

    # str_type should be of the form "int[]"
    if "[" not in str_type:
        raise ValueError(f"Can't determine the Transonic type from '{str_type}'")

    dtype, str_shape = str_type.split("[", 1)
    dtype = dtype.strip()
    # Builtin scalar types are used as is: the np.int / np.float / np.complex
    # aliases were removed from NumPy (1.24).
    if not dtype.startswith("np.") and dtype not in ("int", "float", "complex"):
        dtype = "np." + dtype
    str_shape = "[" + str_shape
    dtype = eval(dtype, {"np": np})
    return Array[dtype, str_shape]
# Python scalar types for which ``typeof`` directly returns ``type(obj)``.
_simple_types = (int, float, complex, str)
def typeof(obj):
    """Compute the Transonic type corresponding to a Python object

    Supports:

    - simple Python types (int, float, complex, str)
    - homogeneous list, dict and set
    - tuple
    - numpy scalars (and 0-d arrays, typed as their scalar type)
    - numpy arrays

    Raises
    ------
    ValueError
        For an empty or non-homogeneous list, dict or set.
    NotImplementedError
        For unsupported objects.
    """
    if isinstance(obj, _simple_types):
        return type(obj)

    if isinstance(obj, tuple):
        return Tuple[tuple(typeof(elem) for elem in obj)]

    if isinstance(obj, (list, dict, set)) and not obj:
        raise ValueError(
            f"Cannot determine the full type of an empty {type(obj)}"
        )

    if isinstance(obj, list):
        type_elem = type(obj[0])
        if not all(isinstance(elem, type_elem) for elem in obj):
            raise ValueError(f"The list {obj} is not homogeneous in type")
        return List[typeof(obj[0])]

    if isinstance(obj, (dict, set)):
        kind = type(obj).__name__
        key = next(iter(obj))
        type_key = type(key)
        if not all(isinstance(other, type_key) for other in obj):
            raise ValueError(f"The {kind} {obj} is not homogeneous in type")
        if isinstance(obj, dict):
            value = next(iter(obj.values()))
            type_value = type(value)
            if not all(isinstance(other, type_value) for other in obj.values()):
                raise ValueError(f"The {kind} {obj} is not homogeneous in type")
            return Dict[typeof(key), typeof(value)]
        return Set[typeof(key)]

    if isinstance(obj, np.ndarray):
        # np.isscalar is always False for ndarray instances (even 0-d ones),
        # so the number of dimensions is used to detect array scalars.
        if obj.ndim == 0:
            return obj.dtype.type
        # TODO: deeper analysis
        return Array[obj.dtype, f"{obj.ndim}d"]

    if isinstance(obj, np.generic):
        return type(obj)

    raise NotImplementedError(
        f"Not able to determine the full type of {obj} (of type {type(obj)})"
    )
class ConstType(Type):
    """Private API class for const

    Wraps a type (``self.type``) and marks it as constant; most methods
    delegate to the wrapped type.
    """

    def __init__(self, type_):
        # NOTE(review): Type/TemplateVar.__init__ is not called, so instances
        # have no ``values`` / ``__name__``; the methods below delegate to
        # the wrapped type instead.
        self.type = type_

    def format_as_backend_type(self, backend_type_formatter, **kwargs):
        return backend_type_formatter.make_const_code(
            format_type_as_backend_type(
                self.type, backend_type_formatter, **kwargs
            )
        )

    def __repr__(self):
        return f"const({repr(self.type)})"

    def is_fused_type(self):
        # Plain Python types (e.g. ``const(int)``) have no is_fused_type.
        if hasattr(self.type, "is_fused_type"):
            return self.type.is_fused_type()
        return False

    def get_template_parameters(self):
        # Plain Python types have no template parameters.
        if hasattr(self.type, "get_template_parameters"):
            return self.type.get_template_parameters()
        return tuple()

    def short_repr(self):
        """Identifier-friendly repr, e.g. ``constI<short repr of type>I``."""
        if hasattr(self.type, "short_repr"):
            short_repr_type = self.type.short_repr()
        else:
            short_repr_type = repr(self.type)
        return f"constI{short_repr_type}I"
def const(type_):
    """Declare a type as constant (``const`` C/Cython keyword)

    Returns a ``ConstType`` wrapping ``type_``.
    """
    return ConstType(type_)