Source code for arkouda.pandas.join

from typing import (
    TYPE_CHECKING,
    Callable,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
    Union,
    cast,
)

import numpy as np

from typeguard import typechecked

from arkouda.numpy.dtypes import NUMBER_FORMAT_STRINGS, resolve_scalar_dtype
from arkouda.numpy.dtypes import int64 as akint64
from arkouda.numpy.pdarrayclass import create_pdarray, pdarray
from arkouda.numpy.pdarraysetops import concatenate, in1d
from arkouda.pandas.categorical import Categorical
from arkouda.pandas.groupbyclass import GroupBy, broadcast, groupable_element_type


if TYPE_CHECKING:
    # The real Strings class is imported only when a static type checker runs.
    from arkouda.numpy.strings import Strings
else:
    # At runtime a placeholder TypeVar stands in for Strings so that annotations
    # mentioning it (e.g. on inner_join) still evaluate.
    # NOTE(review): presumably this avoids a runtime/circular import of
    # arkouda.numpy.strings — confirm.
    Strings = TypeVar("Strings")

# Public API of this module. NOTE(review): inner_join, defined later in this
# module, is not listed here.
__all__ = ["join_on_eq_with_dt", "gen_ranges", "compute_join_size"]

# Maps the names accepted by join_on_eq_with_dt's ``pred`` argument to the integer
# code that is formatted into the server "joinEqWithDT" request.
predicates = {"true_dt": 0, "abs_dt": 1, "pos_dt": 2}


@typechecked
def join_on_eq_with_dt(
    a1: pdarray,
    a2: pdarray,
    t1: pdarray,
    t2: pdarray,
    dt: Union[int, np.int64],
    pred: str,
    result_limit: Union[int, np.int64] = 1000,
) -> Tuple[pdarray, pdarray]:
    """
    Inner-join on equality between two integer arrays where the time-window predicate is also true.

    Parameters
    ----------
    a1 : pdarray
        Values to join (must be int64 dtype).
    a2 : pdarray
        Values to join (must be int64 dtype).
    t1 : pdarray
        Timestamps in millis corresponding to the a1 pdarray (int64, same size as a1).
    t2 : pdarray
        Timestamps in millis corresponding to the a2 pdarray (int64, same size as a2).
    dt : Union[int,np.int64]
        Time delta used by the time-window predicate.
    pred : str
        Time-window predicate; one of 'true_dt', 'abs_dt', or 'pos_dt'.
    result_limit : Union[int,np.int64]
        Size limit for the returned result.

    Returns
    -------
    Tuple[pdarray, pdarray]
        result_array_one : pdarray, int64
            a1 indices where a1 == a2
        result_array_two : pdarray, int64
            a2 indices where a2 == a1

    Raises
    ------
    TypeError
        Raised if a1, a2, t1, or t2 is not a pdarray, or if dt or
        result_limit is not an int
    ValueError
        Raised if a1, a2, t1, or t2 dtype is not int64, if t1 is not the same
        size as a1 or t2 is not the same size as a2, if pred is not 'true_dt',
        'abs_dt', or 'pos_dt', or if result_limit is < 0
    """
    from arkouda.core.client import generic_msg

    for name, arr in (("a1", a1), ("a2", a2), ("t1", t1), ("t2", t2)):
        if arr.dtype != akint64:
            raise ValueError(f"{name} must be int64 dtype")
    # Each timestamp array must line up element-for-element with its value array.
    if t1.size != a1.size:
        raise ValueError("t1 must be the same size as a1")
    if t2.size != a2.size:
        raise ValueError("t2 must be the same size as a2")
    if pred not in predicates:
        raise ValueError(f"pred must be one of {list(predicates)}")
    if result_limit < 0:
        raise ValueError("the result_limit must be 0 or greater")

    # Format the numeric arguments as strings for the server request message.
    dtstr = NUMBER_FORMAT_STRINGS[resolve_scalar_dtype(dt)].format(dt)
    pred_code = predicates[pred]
    predstr = NUMBER_FORMAT_STRINGS[resolve_scalar_dtype(pred_code)].format(pred_code)
    result_limitstr = NUMBER_FORMAT_STRINGS[resolve_scalar_dtype(result_limit)].format(
        result_limit
    )

    # Group a2 so the server can look up, for each a1 value, the matching a2 segment.
    g2 = GroupBy(a2)
    rep_msg = generic_msg(
        cmd="joinEqWithDT",
        args={
            "a1": a1,
            "g2seg": cast(pdarray, g2.segments),
            "g2keys": cast(pdarray, g2.unique_keys),
            "g2perm": g2.permutation,
            "t1": t1,
            "t2": t2,
            "dt": dtstr,
            "pred": predstr,
            "resLimit": result_limitstr,
        },
    )
    # The reply names two result arrays separated by "+".
    res_i_attr, res_j_attr = cast(str, rep_msg).split("+")
    res_i = create_pdarray(res_i_attr)
    res_j = create_pdarray(res_j_attr)
    return res_i, res_j
def gen_ranges(starts, ends, stride=1, return_lengths=False):
    """
    Generate a segmented array of variable-length, contiguous ranges between pairs of
    start- and end-points.

    Range i is ``starts[i], starts[i] + stride, ...`` up to but excluding ``ends[i]``,
    i.e. it has ``ceil((ends[i] - starts[i]) / stride)`` elements.

    For example, starts=[0, 10], ends=[5, 13], stride=2 gives segments [0, 3] and
    ranges [0, 2, 4, 10, 12].

    Parameters
    ----------
    starts : pdarray, int64
        The start value of each range
    ends : pdarray, int64
        The end value (exclusive) of each range
    stride: int
        Difference between successive elements of each range; must be >= 1.
    return_lengths: bool, optional
        Whether or not to return the lengths of each segment. Default False.

    Returns
    -------
    pdarray|int64, pdarray|int64, pdarray|int64
        segments : pdarray, int64
            The starting index of each range in the resulting array
        ranges : pdarray, int64
            The actual ranges, flattened into a single array
        lengths : pdarray, int64
            The lengths of each segment. Only returned if return_lengths=True.

    Raises
    ------
    ValueError
        If starts and ends differ in size, if stride < 1, or if any end is less
        than its start.
    """
    from arkouda.numpy import cumsum
    from arkouda.numpy.pdarraycreation import array, ones, zeros

    if starts.size != ends.size:
        raise ValueError("starts and ends must be same length")
    if stride < 1:
        raise ValueError("stride must be a positive integer")
    if starts.size == 0:
        segs = zeros(0, dtype=akint64)
        ranges = zeros(0, dtype=akint64)
        if return_lengths:
            return segs, ranges, zeros(0, dtype=akint64)
        return segs, ranges

    if not (ends >= starts).all():
        raise ValueError("all ends must be greater than or equal to starts")
    # Ceiling division: a half-open range [s, e) with step k holds ceil((e - s) / k) values.
    lengths = (ends - starts + (stride - 1)) // stride
    segs = cumsum(lengths) - lengths
    totlen = lengths.sum()

    if totlen == 0:
        # Every range is empty: nothing to generate (and no first start to index).
        ranges = zeros(0, dtype=akint64)
    else:
        non_empty = lengths != 0
        non_empty_starts = starts[non_empty]
        non_empty_lengths = lengths[non_empty]
        # Build per-element increments and cumsum them. Inside a range successive
        # values differ by `stride`; the first value of each non-empty range instead
        # carries the jump from the last value of the previous non-empty range.
        slices = ones(totlen, dtype=akint64) * stride
        diffs = concatenate(
            (
                array([non_empty_starts[0]]),
                non_empty_starts[1:]
                - non_empty_starts[:-1]
                - (non_empty_lengths[:-1] - 1) * stride,
            )
        )
        slices[segs[non_empty]] = diffs
        ranges = cumsum(slices)

    if return_lengths:
        return segs, ranges, lengths
    return segs, ranges
@typechecked
def compute_join_size(a: pdarray, b: pdarray) -> Tuple[int, int]:
    """
    Compute the internal size of a hypothetical join between a and b.

    Returns both the number of elements and number of bytes required for the join.

    Parameters
    ----------
    a : pdarray
        Left join keys.
    b : pdarray
        Right join keys.

    Returns
    -------
    Tuple[int, int]
        nelem : int
            Number of matching (a, b) pairs: the sum, over keys present in both
            inputs, of (count in a) * (count in b).
        nbytes : int
            ``3 * 8 * nelem``. NOTE(review): presumably three 8-byte values per
            joined element — confirm.
    """
    ua, asize = GroupBy(a).size()
    ub, bsize = GroupBy(b).size()
    ua_keys = cast(groupable_element_type, ua)
    ub_keys = cast(groupable_element_type, ub)
    # Per-key multiplicities restricted to keys present on both sides.
    # NOTE(review): this assumes GroupBy yields unique keys in the same (sorted)
    # order for both inputs, so the two filtered count arrays align key-for-key.
    afact = asize[in1d(ua_keys, ub_keys)]
    bfact = bsize[in1d(ub_keys, ua_keys)]
    # The reduction returns a numpy scalar; convert so the declared Tuple[int, int]
    # return type holds under @typechecked.
    nelem = int((afact * bfact).sum())
    nbytes = 3 * 8 * nelem
    return nelem, nbytes
@typechecked
def inner_join(
    left: Union[pdarray, Strings, Categorical, Sequence[Union[pdarray, Strings]]],
    right: Union[pdarray, Strings, Categorical, Sequence[Union[pdarray, Strings]]],
    wherefunc: Optional[Callable] = None,
    whereargs: Optional[
        Tuple[Union[pdarray, Strings, Categorical, Sequence[Union[pdarray, Strings]]], ...]
    ] = None,
) -> Tuple[pdarray, pdarray]:
    """
    Perform inner join on values in <left> and <right>,
    using conditions defined by <wherefunc> evaluated on
    <whereargs>, returning indices of left-right pairs.

    Parameters
    ----------
    left : pdarray(int64), Strings, Categorical, or Sequence of pdarray
        The left values to join
    right : pdarray(int64), Strings, Categorical, or Sequence of pdarray
        The right values to join (must be a sequence exactly when left is)
    wherefunc : function, optional
        Function that takes two arguments (the left and right whereargs, gathered
        per candidate pair) and returns a pdarray(bool) used to filter the join.
        Results for which wherefunc is False will be dropped.
    whereargs : 2-tuple of pdarray, Strings, Categorical, or Sequence of pdarray, optional
        The two arguments for wherefunc. Both must be sequences or both single
        arrays, independently of whether left/right are sequences.

    Returns
    -------
    Tuple[pdarray, pdarray]
        left_inds : pdarray(int64)
            The left indices of pairs that meet the join condition
        right_inds : pdarray(int64)
            The right indices of pairs that meet the join condition

    Raises
    ------
    ValueError
        If left/right (or the two whereargs) are mismatched in kind, array count,
        or array sizes; if wherefunc does not take exactly two arguments; if
        whereargs is not a 2-tuple; or if wherefunc fails on a small sample.

    Notes
    -----
    The return values satisfy the following assertions

    `assert (left[left_inds] == right[right_inds]).all()`
    `assert wherefunc(whereargs[0][left_inds], whereargs[1][right_inds]).all()`
    """
    from inspect import signature

    from arkouda.numpy import cumsum
    from arkouda.numpy.alignment import right_align
    from arkouda.numpy.pdarraycreation import arange, array, zeros

    if isinstance(left, Sequence) != isinstance(right, Sequence):
        raise ValueError(
            "left and right must both be sequences of arrays or both be single arrays"
        )
    is_sequence = isinstance(left, Sequence)

    # Reduce processing to codes to prevent groupby on entire Categorical
    if isinstance(left, Categorical) and isinstance(right, Categorical):
        lft, r = Categorical.standardize_categories([left, right])
        left, right = lft.codes, r.codes

    if is_sequence:
        if len(left) != len(right):
            raise ValueError("Left must have same num arrays as right")
        if len(left) == 0:
            raise ValueError("Multi-array arguments must contain at least one array")
        left_size, right_size = left[0].size, right[0].size
        if not all(lf.size == left_size for lf in left) or not all(
            rt.size == right_size for rt in right
        ):
            raise ValueError("Multi-array arguments must have equal-length arrays")
    else:
        left_size, right_size = left.size, right.size  # type: ignore
    # Probe wherefunc on at most this many elements per side (equal on both sides).
    sample = min(left_size, right_size, 5)

    wa_is_sequence = False
    if wherefunc is not None:
        if len(signature(wherefunc).parameters) != 2:
            raise ValueError("wherefunc must be a function that accepts exactly two arguments")
        if whereargs is None or len(whereargs) != 2:
            raise ValueError("whereargs must be a 2-tuple with left and right arg arrays")
        left_wa, right_wa = whereargs
        # whereargs' shape is independent of whether the join keys are sequences.
        if isinstance(left_wa, Sequence) != isinstance(right_wa, Sequence):
            raise ValueError(
                "whereargs must both be sequences of arrays or both be single arrays"
            )
        wa_is_sequence = isinstance(left_wa, Sequence)
        if wa_is_sequence:
            if len(left_wa) != len(right_wa):
                raise ValueError("Left must have same num arrays as right")
            if len(left_wa) == 0:
                raise ValueError("Multi-array arguments must contain at least one array")
            first_wa_size, second_wa_size = left_wa[0].size, right_wa[0].size
            if not all(wa.size == first_wa_size for wa in left_wa) or not all(
                wa.size == second_wa_size for wa in right_wa
            ):
                raise ValueError("Multi-array arguments must have equal-length arrays")
            # Slice each array (not the list of arrays) so both sides are sample-sized.
            sample_args = (
                [wa[:sample] for wa in left_wa],
                [wa[:sample] for wa in right_wa],
            )
        else:
            first_wa_size, second_wa_size = left_wa.size, right_wa.size  # type: ignore
            sample_args = (left_wa[:sample], right_wa[:sample])
        if first_wa_size != left_size:
            raise ValueError("Left whereargs must be same size as left join values")
        if second_wa_size != right_size:
            raise ValueError("Right whereargs must be same size as right join values")
        try:
            _ = wherefunc(*sample_args)
        except Exception as e:
            raise ValueError("Error evaluating wherefunc") from e

    # Need dense 0-up right index, to filter out left not in right
    keep, (dense_left, dense_right) = right_align(left, right)
    if keep.sum() == 0:
        # Intersection is empty
        return zeros(0, dtype=akint64), zeros(0, dtype=akint64)
    # Convert the boolean mask over left into the indices of the kept left rows.
    keep = arange(keep.size)[keep]

    # GroupBy right
    by_right = GroupBy(dense_right)
    # Get segment boundaries (starts, ends) of right for each left item
    right_segs = concatenate((by_right.segments, array([dense_right.size])))
    starts = right_segs[dense_left]
    ends = right_segs[dense_left + 1]
    # gen_ranges for gather of right items
    full_segs, ranges = gen_ranges(starts, ends)

    # Evaluate where clause
    if wherefunc is None:
        filt_ranges = ranges
        filt_segs = full_segs
        keep12 = keep
    else:
        if wa_is_sequence:
            # Gather right whereargs
            right_where = [wa[by_right.permutation][ranges] for wa in right_wa]
            # Expand left whereargs
            keep_where = [wa[keep] for wa in left_wa]
            left_where = [broadcast(full_segs, kw, ranges.size) for kw in keep_where]
        else:
            # Gather right whereargs
            right_where = right_wa[by_right.permutation][ranges]
            # Expand left whereargs (broadcast operates on codes for a Categorical).
            # NOTE(review): the right side is not converted to codes here — confirm
            # this asymmetry is intended.
            keep_where = left_wa[keep]
            keep_where = keep_where.codes if isinstance(keep_where, Categorical) else keep_where
            left_where = broadcast(full_segs, keep_where, ranges.size)
        # Evaluate wherefunc and filter ranges, recompute segments
        where_satisfied = wherefunc(left_where, right_where)
        n_satisfied = where_satisfied.sum()
        if n_satisfied == 0:
            # The where clause rejected every candidate pair.
            return zeros(0, dtype=akint64), zeros(0, dtype=akint64)
        filt_ranges = ranges[where_satisfied]
        # Exclusive scan: new start offset of each old segment after filtering.
        scan = cumsum(where_satisfied) - where_satisfied
        filt_segs_with_zeros = scan[full_segs]
        filt_seg_sizes = concatenate(
            (
                filt_segs_with_zeros[1:] - filt_segs_with_zeros[:-1],
                array([n_satisfied - filt_segs_with_zeros[-1]]),
            )
        )
        # Drop left rows whose every candidate pair was filtered out.
        keep2 = filt_seg_sizes > 0
        filt_segs = filt_segs_with_zeros[keep2]
        keep12 = keep[keep2]

    # Gather right inds and expand left inds (keep12 already holds left indices).
    right_inds = by_right.permutation[filt_ranges]
    left_inds = broadcast(filt_segs, keep12, filt_ranges.size)
    return left_inds, right_inds