from __future__ import annotations
from typing import Literal, Optional, Tuple, cast
import arkouda as ak
from arkouda.numpy import cast as akcast
from arkouda.numpy.pdarrayclass import create_pdarray, create_pdarrays
from ._dtypes import _real_floating_dtypes, _real_numeric_dtypes
from .array_object import Array
from .manipulation_functions import broadcast_arrays
# Names exported by a star-import of this module; each one is a function
# defined below in this file.
__all__ = [
"argmax",
"argmin",
"nonzero",
"searchsorted",
"where",
]
[docs]
def argmax(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array:
    """
    Locate the maximum values of an array along an axis and return their indices.

    Parameters
    ----------
    x : Array
        Input array. Its dtype must be one of the real numeric dtypes.
    axis : int, optional
        Axis to search along. If None, the array is flattened before searching.
    keepdims : bool, optional
        Whether to keep the singleton dimension along `axis` in the result.

    Returns
    -------
    Array
        The result of ``ak.argmax`` on the underlying array, wrapped in an ``Array``.

    Raises
    ------
    TypeError
        If ``x.dtype`` is not a real numeric dtype.
    """
    if x.dtype in _real_numeric_dtypes:
        indices = ak.argmax(x._array, axis=axis, keepdims=keepdims)
        return Array._new(indices)

    raise TypeError("Only real numeric dtypes are allowed in argmax")
[docs]
def argmin(x: Array, /, *, axis: Optional[int] = None, keepdims: bool = False) -> Array:
    """
    Return an array with the indices of the minimum values along a given axis.

    Parameters
    ----------
    x : Array
        The array to search for minimum values. Its dtype must be one of the
        real numeric dtypes.
    axis : int, optional
        The axis along which to search for minimum values. If None, the array is
        flattened before searching.
    keepdims : bool, optional
        Whether to keep the singleton dimension along `axis` in the result.

    Returns
    -------
    Array
        The result of ``ak.argmin`` on the underlying array, wrapped in an ``Array``.

    Raises
    ------
    TypeError
        If ``x.dtype`` is not a real numeric dtype.
    """
    if x.dtype not in _real_numeric_dtypes:
        # The message must name this function; it previously said "argmax".
        raise TypeError("Only real numeric dtypes are allowed in argmin")
    return Array._new(ak.argmin(x._array, axis=axis, keepdims=keepdims))
[docs]
def nonzero(x: Array, /) -> Tuple[Array, ...]:
    """
    Find the indices of the non-zero elements of the input array.

    Parameters
    ----------
    x : Array
        The array to inspect. Its dtype and number of dimensions are used to
        select the server command.

    Returns
    -------
    Tuple[Array, ...]
        One ``Array`` for each pdarray decoded from the server's reply.
    """
    # Imported here rather than at module level, as in the rest of this file.
    from arkouda.client import generic_msg

    command = f"nonzero<{x.dtype},{x.ndim}>"
    reply = cast(str, generic_msg(cmd=command, args={"x": x._array}))
    return tuple(Array._new(index_array) for index_array in create_pdarrays(reply))
[docs]
def where(condition: Array, x1: Array, x2: Array, /) -> Array:
    """
    Choose elements from `x1` or `x2` according to `condition`.

    The three inputs are broadcast against each other first, and the
    broadcast condition is cast to a boolean array before the server is asked
    to perform the selection.

    Parameters
    ----------
    condition : Array
        Where condition[i] is True the output holds x1[i], otherwise x2[i].
    x1 : Array
        Values used at positions where `condition` is True.
    x2 : Array
        Values used at positions where `condition` is False.

    Returns
    -------
    Array
        The array created from the server's reply.
    """
    from arkouda.client import generic_msg

    cond_b, x1_b, x2_b = broadcast_arrays(condition, x1, x2)
    true_values = x1_b._array
    false_values = x2_b._array
    mask = akcast(cond_b._array, ak.bool_)

    command = f"wherevv<{mask.ndim},{true_values.dtype},{false_values.dtype}>"
    reply = generic_msg(
        cmd=command,
        args={"condition": mask, "a": true_values, "b": false_values},
    )
    return Array._new(create_pdarray(reply))
[docs]
def searchsorted(
    x1: Array,
    x2: Array,
    /,
    *,
    side: Literal["left", "right"] = "left",
    sorter: Optional[Array] = None,
) -> Array:
    """
    Given a sorted array `x1`, find the indices to insert elements from another array `x2` such that
    the sorted order is maintained.

    Parameters
    ----------
    x1 : Array
        The sorted array to search in. Must be at most one-dimensional and have a
        real floating-point dtype.
    x2 : Array
        The values to search for in `x1`. Must have a real floating-point dtype.
    side : {'left', 'right'}, optional
        If 'left', the index of the first suitable location found is given. If 'right', return the
        last such index. Default is 'left'.
    sorter : Array, optional
        The indices that would sort `x1` in ascending order. If None, `x1` is assumed to be sorted.

    Returns
    -------
    Array
        The array created from the server's reply.

    Raises
    ------
    TypeError
        If the dtype of `x1` or `x2` is not a real floating-point dtype.
    ValueError
        If `x1` has more than one dimension.
    """
    from arkouda.client import generic_msg

    if x1.dtype not in _real_floating_dtypes or x2.dtype not in _real_floating_dtypes:
        # The check is against the floating-point dtypes only, so say so; the
        # previous "Only real dtypes" wording was misleading for integer inputs.
        raise TypeError("Only real floating-point dtypes are allowed in searchsorted")

    if x1.ndim > 1:
        raise ValueError("searchsorted only supports 1D arrays for x1")

    # Apply the caller-supplied permutation so the server receives x1 in sorted order.
    _x1 = x1[sorter] if sorter is not None else x1

    resp = generic_msg(
        cmd=f"searchSorted<{x1.dtype},1,{x2.ndim}>",
        args={
            "x1": _x1._array,
            "x2": x2._array,
            "side": side,
        },
    )
    return Array._new(create_pdarray(resp))