| Viewing file:  _creation_functions.py (9.81 KB)      -rw-r--r-- Select action/file-type:
 
  (+) |  (+) |  (+) | Code (+) | Session (+) |  (+) | SDB (+) |  (+) |  (+) |  (+) |  (+) |  (+) | 
 
from __future__ import annotations
 
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 if TYPE_CHECKING:
 from ._typing import (
 Array,
 Device,
 Dtype,
 NestedSequence,
 SupportsBufferProtocol,
 )
 from collections.abc import Sequence
 from ._dtypes import _all_dtypes
 
 import numpy as np
 
 
 def _check_valid_dtype(dtype):
 # Note: Only spelling dtypes as the dtype objects is supported.
 
 # We use this instead of "dtype in _all_dtypes" because the dtype objects
 # define equality with the sorts of things we want to disallow.
 for d in (None,) + _all_dtypes:
 if dtype is d:
 return
 raise ValueError("dtype must be one of the supported dtypes")
 
 
 def asarray(
 obj: Union[
 Array,
 bool,
 int,
 float,
 NestedSequence[bool | int | float],
 SupportsBufferProtocol,
 ],
 /,
 *,
 dtype: Optional[Dtype] = None,
 device: Optional[Device] = None,
 copy: Optional[Union[bool, np._CopyMode]] = None,
 ) -> Array:
 """
 Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.
 
 See its docstring for more information.
 """
 # _array_object imports in this file are inside the functions to avoid
 # circular imports
 from ._array_object import Array
 
 _check_valid_dtype(dtype)
 if device not in ["cpu", None]:
 raise ValueError(f"Unsupported device {device!r}")
 if copy in (False, np._CopyMode.IF_NEEDED):
 # Note: copy=False is not yet implemented in np.asarray
 raise NotImplementedError("copy=False is not yet implemented")
 if isinstance(obj, Array):
 if dtype is not None and obj.dtype != dtype:
 copy = True
 if copy in (True, np._CopyMode.ALWAYS):
 return Array._new(np.array(obj._array, copy=True, dtype=dtype))
 return obj
 if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
 # Give a better error message in this case. NumPy would convert this
 # to an object array. TODO: This won't handle large integers in lists.
 raise OverflowError("Integer out of bounds for array dtypes")
 res = np.asarray(obj, dtype=dtype)
 return Array._new(res)
 
 
 def arange(
 start: Union[int, float],
 /,
 stop: Optional[Union[int, float]] = None,
 step: Union[int, float] = 1,
 *,
 dtype: Optional[Dtype] = None,
 device: Optional[Device] = None,
 ) -> Array:
 """
 Array API compatible wrapper for :py:func:`np.arange <numpy.arange>`.
 
 See its docstring for more information.
 """
 from ._array_object import Array
 
 _check_valid_dtype(dtype)
 if device not in ["cpu", None]:
 raise ValueError(f"Unsupported device {device!r}")
 return Array._new(np.arange(start, stop=stop, step=step, dtype=dtype))
 
 
 def empty(
 shape: Union[int, Tuple[int, ...]],
 *,
 dtype: Optional[Dtype] = None,
 device: Optional[Device] = None,
 ) -> Array:
 """
 Array API compatible wrapper for :py:func:`np.empty <numpy.empty>`.
 
 See its docstring for more information.
 """
 from ._array_object import Array
 
 _check_valid_dtype(dtype)
 if device not in ["cpu", None]:
 raise ValueError(f"Unsupported device {device!r}")
 return Array._new(np.empty(shape, dtype=dtype))
 
 
 def empty_like(
 x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
 ) -> Array:
 """
 Array API compatible wrapper for :py:func:`np.empty_like <numpy.empty_like>`.
 
 See its docstring for more information.
 """
 from ._array_object import Array
 
 _check_valid_dtype(dtype)
 if device not in ["cpu", None]:
 raise ValueError(f"Unsupported device {device!r}")
 return Array._new(np.empty_like(x._array, dtype=dtype))
 
 
 def eye(
 n_rows: int,
 n_cols: Optional[int] = None,
 /,
 *,
 k: int = 0,
 dtype: Optional[Dtype] = None,
 device: Optional[Device] = None,
 ) -> Array:
 """
 Array API compatible wrapper for :py:func:`np.eye <numpy.eye>`.
 
 See its docstring for more information.
 """
 from ._array_object import Array
 
 _check_valid_dtype(dtype)
 if device not in ["cpu", None]:
 raise ValueError(f"Unsupported device {device!r}")
 return Array._new(np.eye(n_rows, M=n_cols, k=k, dtype=dtype))
 
 
 def from_dlpack(x: object, /) -> Array:
 from ._array_object import Array
 
 return Array._new(np.from_dlpack(x))
 
 
 def full(
 shape: Union[int, Tuple[int, ...]],
 fill_value: Union[int, float],
 *,
 dtype: Optional[Dtype] = None,
 device: Optional[Device] = None,
 ) -> Array:
 """
 Array API compatible wrapper for :py:func:`np.full <numpy.full>`.
 
 See its docstring for more information.
 """
 from ._array_object import Array
 
 _check_valid_dtype(dtype)
 if device not in ["cpu", None]:
 raise ValueError(f"Unsupported device {device!r}")
 if isinstance(fill_value, Array) and fill_value.ndim == 0:
 fill_value = fill_value._array
 res = np.full(shape, fill_value, dtype=dtype)
 if res.dtype not in _all_dtypes:
 # This will happen if the fill value is not something that NumPy
 # coerces to one of the acceptable dtypes.
 raise TypeError("Invalid input to full")
 return Array._new(res)
 
 
 def full_like(
 x: Array,
 /,
 fill_value: Union[int, float],
 *,
 dtype: Optional[Dtype] = None,
 device: Optional[Device] = None,
 ) -> Array:
 """
 Array API compatible wrapper for :py:func:`np.full_like <numpy.full_like>`.
 
 See its docstring for more information.
 """
 from ._array_object import Array
 
 _check_valid_dtype(dtype)
 if device not in ["cpu", None]:
 raise ValueError(f"Unsupported device {device!r}")
 res = np.full_like(x._array, fill_value, dtype=dtype)
 if res.dtype not in _all_dtypes:
 # This will happen if the fill value is not something that NumPy
 # coerces to one of the acceptable dtypes.
 raise TypeError("Invalid input to full_like")
 return Array._new(res)
 
 
 def linspace(
 start: Union[int, float],
 stop: Union[int, float],
 /,
 num: int,
 *,
 dtype: Optional[Dtype] = None,
 device: Optional[Device] = None,
 endpoint: bool = True,
 ) -> Array:
 """
 Array API compatible wrapper for :py:func:`np.linspace <numpy.linspace>`.
 
 See its docstring for more information.
 """
 from ._array_object import Array
 
 _check_valid_dtype(dtype)
 if device not in ["cpu", None]:
 raise ValueError(f"Unsupported device {device!r}")
 return Array._new(np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint))
 
 
 def meshgrid(*arrays: Array, indexing: str = "xy") -> List[Array]:
 """
 Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.
 
 See its docstring for more information.
 """
 from ._array_object import Array
 
 # Note: unlike np.meshgrid, only inputs with all the same dtype are
 # allowed
 
 if len({a.dtype for a in arrays}) > 1:
 raise ValueError("meshgrid inputs must all have the same dtype")
 
 return [
 Array._new(array)
 for array in np.meshgrid(*[a._array for a in arrays], indexing=indexing)
 ]
 
 
 def ones(
 shape: Union[int, Tuple[int, ...]],
 *,
 dtype: Optional[Dtype] = None,
 device: Optional[Device] = None,
 ) -> Array:
 """
 Array API compatible wrapper for :py:func:`np.ones <numpy.ones>`.
 
 See its docstring for more information.
 """
 from ._array_object import Array
 
 _check_valid_dtype(dtype)
 if device not in ["cpu", None]:
 raise ValueError(f"Unsupported device {device!r}")
 return Array._new(np.ones(shape, dtype=dtype))
 
 
 def ones_like(
 x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
 ) -> Array:
 """
 Array API compatible wrapper for :py:func:`np.ones_like <numpy.ones_like>`.
 
 See its docstring for more information.
 """
 from ._array_object import Array
 
 _check_valid_dtype(dtype)
 if device not in ["cpu", None]:
 raise ValueError(f"Unsupported device {device!r}")
 return Array._new(np.ones_like(x._array, dtype=dtype))
 
 
 def tril(x: Array, /, *, k: int = 0) -> Array:
 """
 Array API compatible wrapper for :py:func:`np.tril <numpy.tril>`.
 
 See its docstring for more information.
 """
 from ._array_object import Array
 
 if x.ndim < 2:
 # Note: Unlike np.tril, x must be at least 2-D
 raise ValueError("x must be at least 2-dimensional for tril")
 return Array._new(np.tril(x._array, k=k))
 
 
 def triu(x: Array, /, *, k: int = 0) -> Array:
 """
 Array API compatible wrapper for :py:func:`np.triu <numpy.triu>`.
 
 See its docstring for more information.
 """
 from ._array_object import Array
 
 if x.ndim < 2:
 # Note: Unlike np.triu, x must be at least 2-D
 raise ValueError("x must be at least 2-dimensional for triu")
 return Array._new(np.triu(x._array, k=k))
 
 
 def zeros(
 shape: Union[int, Tuple[int, ...]],
 *,
 dtype: Optional[Dtype] = None,
 device: Optional[Device] = None,
 ) -> Array:
 """
 Array API compatible wrapper for :py:func:`np.zeros <numpy.zeros>`.
 
 See its docstring for more information.
 """
 from ._array_object import Array
 
 _check_valid_dtype(dtype)
 if device not in ["cpu", None]:
 raise ValueError(f"Unsupported device {device!r}")
 return Array._new(np.zeros(shape, dtype=dtype))
 
 
 def zeros_like(
 x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None
 ) -> Array:
 """
 Array API compatible wrapper for :py:func:`np.zeros_like <numpy.zeros_like>`.
 
 See its docstring for more information.
 """
 from ._array_object import Array
 
 _check_valid_dtype(dtype)
 if device not in ["cpu", None]:
 raise ValueError(f"Unsupported device {device!r}")
 return Array._new(np.zeros_like(x._array, dtype=dtype))
 
 |