# Source code for arkouda.join

from typing import Callable, Optional, Sequence, Tuple, Union, cast

import numpy as np
from typeguard import typechecked

from arkouda.alignment import right_align
from arkouda.categorical import Categorical
from arkouda.client import generic_msg
from arkouda.numpy.dtypes import NUMBER_FORMAT_STRINGS
from arkouda.numpy.dtypes import int64 as akint64
from arkouda.numpy.dtypes import resolve_scalar_dtype
from arkouda.groupbyclass import GroupBy, broadcast
from arkouda.numeric import cumsum
from arkouda.pdarrayclass import create_pdarray, pdarray
from arkouda.pdarraycreation import arange, array, ones, zeros
from arkouda.pdarraysetops import concatenate, in1d
from arkouda.strings import Strings

# Names exported by ``from arkouda.join import *``.
__all__ = [
    "join_on_eq_with_dt",
    "gen_ranges",
    "compute_join_size",
]

# Integer code sent to the server (as the "pred" argument of joinEqWithDT)
# for each time-window predicate name accepted by join_on_eq_with_dt.
predicates = dict(true_dt=0, abs_dt=1, pos_dt=2)


@typechecked
def join_on_eq_with_dt(
    a1: pdarray,
    a2: pdarray,
    t1: pdarray,
    t2: pdarray,
    dt: Union[int, np.int64],
    pred: str,
    result_limit: Union[int, np.int64] = 1000,
) -> Tuple[pdarray, pdarray]:
    """
    Performs an inner-join on equality between two integer arrays where
    the time-window predicate is also true

    Parameters
    ----------
    a1 : pdarray, int64
        pdarray to be joined
    a2 : pdarray, int64
        pdarray to be joined
    t1 : pdarray
        timestamps in millis corresponding to the a1 pdarray; must have the
        same size as a1
    t2 : pdarray
        timestamps in millis corresponding to the a2 pdarray; must have the
        same size as a2
    dt : Union[int,np.int64]
        time delta
    pred : str
        time window predicate: one of 'true_dt', 'abs_dt', or 'pos_dt'
    result_limit : Union[int,np.int64]
        size limit for returned result

    Returns
    -------
    result_array_one : pdarray, int64
        a1 indices where a1 == a2
    result_array_two : pdarray, int64
        a2 indices where a2 == a1

    Raises
    ------
    TypeError
        Raised if a1, a2, t1, or t2 is not a pdarray, or if dt or
        result_limit is not an int
    ValueError
        if a1, a2, t1, or t2 dtype is not int64, if t1 / t2 differ in size
        from a1 / a2, if pred is not 'true_dt', 'abs_dt', or 'pos_dt', or if
        result_limit is < 0
    """
    for name, arr in (("a1", a1), ("a2", a2), ("t1", t1), ("t2", t2)):
        if arr.dtype != akint64:
            raise ValueError(f"{name} must be int64 dtype")
    # Each timestamp array annotates its key array element-by-element.
    if t1.size != a1.size:
        raise ValueError("t1 must be the same size as a1")
    if t2.size != a2.size:
        raise ValueError("t2 must be the same size as a2")
    if pred not in predicates:
        raise ValueError(f"pred must be one of {predicates.keys()}")
    if result_limit < 0:
        raise ValueError("the result_limit must be 0 or greater")

    def _fmt(value):
        # Render a scalar using the number format registered for its dtype.
        return NUMBER_FORMAT_STRINGS[resolve_scalar_dtype(value)].format(value)

    # Format scalar arguments for the request message.
    dtstr = _fmt(dt)
    predstr = _fmt(predicates[pred])
    result_limitstr = _fmt(result_limit)

    # Group a2; its segments, unique keys and permutation are sent to the server.
    g2 = GroupBy(a2)
    repMsg = generic_msg(
        cmd="joinEqWithDT",
        args={
            "a1": a1,
            "g2seg": cast(pdarray, g2.segments),
            "g2keys": cast(pdarray, g2.unique_keys),
            "g2perm": g2.permutation,
            "t1": t1,
            "t2": t2,
            "dt": dtstr,
            "pred": predstr,
            "resLimit": result_limitstr,
        },
    )
    # The reply carries the two result array descriptors separated by "+".
    resIAttr, resJAttr = cast(str, repMsg).split("+")
    return create_pdarray(resIAttr), create_pdarray(resJAttr)
def gen_ranges(starts, ends, stride=1, return_lengths=False):
    """
    Generate a segmented array of variable-length, contiguous ranges between
    pairs of start- and end-points.

    Range ``i`` holds the same values as Python's
    ``range(starts[i], ends[i], stride)``.

    Parameters
    ----------
    starts : pdarray, int64
        The start value of each range
    ends : pdarray, int64
        The end value (exclusive) of each range
    stride : int
        Difference between successive elements of each range. Must be
        non-zero; a negative stride produces descending ranges.
    return_lengths : bool, optional
        Whether or not to return the lengths of each segment. Default False.

    Returns
    -------
    segments : pdarray, int64
        The starting index of each range in the resulting array
    ranges : pdarray, int64
        The actual ranges, flattened into a single array
    lengths : pdarray, int64
        The lengths of each segment. Only returned if return_lengths=True.

    Raises
    ------
    ValueError
        If starts and ends differ in size, if stride is 0, or if any range
        runs against the direction of stride.
    """
    if starts.size != ends.size:
        raise ValueError("starts and ends must be same length")
    if starts.size == 0:
        # One empty array per output, so the arity always matches return_lengths.
        n_out = 3 if return_lengths else 2
        return tuple(zeros(0, dtype=akint64) for _ in range(n_out))
    if stride == 0:
        raise ValueError("stride must be non-zero")

    # Each range must run in the direction of stride.
    spans = ends - starts
    if stride > 0:
        if not (spans >= 0).all():
            raise ValueError("all ends must be greater than or equal to starts")
    elif not (spans <= 0).all():
        raise ValueError("all ends must be less than or equal to starts when stride is negative")

    # len(range(s, e, k)) == ceil((e - s) / k); -((s - e) // k) is that ceiling
    # in exact integer arithmetic for either sign of k.
    lengths = -((starts - ends) // stride)
    segs = cumsum(lengths) - lengths
    totlen = lengths.sum()
    if totlen == 0:
        # Every range is empty: there is no first start value to anchor on.
        ranges = zeros(0, dtype=akint64)
        return (segs, ranges, lengths) if return_lengths else (segs, ranges)

    # The output is a running sum of per-element steps: `stride` inside a range,
    # and a jump onto the next range's start at the first element of each range.
    steps = ones(totlen, dtype=akint64) * stride
    non_empty = lengths != 0
    ne_starts = starts[non_empty]
    ne_lengths = lengths[non_empty]
    # Last value of every non-empty range except the final one.
    ne_lasts = ne_starts[:-1] + (ne_lengths[:-1] - 1) * stride
    jumps = concatenate((array([ne_starts[0]]), ne_starts[1:] - ne_lasts))
    # Empty ranges share their offset with the next range, so only non-empty
    # segment offsets receive a jump.
    steps[segs[non_empty]] = jumps
    ranges = cumsum(steps)
    return (segs, ranges, lengths) if return_lengths else (segs, ranges)
@typechecked
def compute_join_size(a: pdarray, b: pdarray) -> Tuple[int, int]:
    """
    Compute the internal size of a hypothetical join between a and b.
    Returns both the number of elements and number of bytes required for
    the join.

    Parameters
    ----------
    a : pdarray
        Left join keys
    b : pdarray
        Right join keys

    Returns
    -------
    nelem : int
        Number of (a, b) pairs with equal keys: for every key present in both
        inputs, count in a times count in b, summed.
    nbytes : int
        ``3 * 8 * nelem``, the byte estimate used for the join.
    """
    ua, asize = GroupBy(a).size()
    ub, bsize = GroupBy(b).size()
    # Per-key counts restricted to keys found on both sides. Multiplying them
    # elementwise relies on ua and ub being in the same order (presumably
    # GroupBy returns sorted unique keys -- verify).
    afact = asize[in1d(ua, ub)]
    bfact = bsize[in1d(ub, ua)]
    # pdarray.sum() yields a numpy scalar; convert so the declared
    # Tuple[int, int] return type (enforced by @typechecked) holds.
    nelem = int((afact * bfact).sum())
    # NOTE(review): 3 * 8 presumably means three 8-byte values per joined pair.
    nbytes = 3 * 8 * nelem
    return nelem, nbytes
@typechecked
def inner_join(
    left: Union[pdarray, Strings, Categorical, Sequence[Union[pdarray, Strings]]],
    right: Union[pdarray, Strings, Categorical, Sequence[Union[pdarray, Strings]]],
    wherefunc: Optional[Callable] = None,
    whereargs: Optional[
        Tuple[
            Union[pdarray, Strings, Categorical, Sequence[Union[pdarray, Strings]]],
            Union[pdarray, Strings, Categorical, Sequence[Union[pdarray, Strings]]],
        ]
    ] = None,
) -> Tuple[pdarray, pdarray]:
    """Perform inner join on values in <left> and <right>,
    using conditions defined by <wherefunc> evaluated on
    <whereargs>, returning indices of left-right pairs.

    Parameters
    ----------
    left : pdarray(int64), Strings, Categorical, or Sequence of pdarray
        The left values to join
    right : pdarray(int64), Strings, Categorical, or Sequence of pdarray
        The right values to join
    wherefunc : function, optional
        Function that takes two pdarray arguments and returns
        a pdarray(bool) used to filter the join. Results for
        which wherefunc is False will be dropped.
    whereargs : 2-tuple of pdarray, Strings, Categorical, or Sequence of pdarray, optional
        The two arguments for wherefunc

    Returns
    -------
    leftInds : pdarray(int64)
        The left indices of pairs that meet the join condition
    rightInds : pdarray(int64)
        The right indices of pairs that meet the join condition

    Raises
    ------
    ValueError
        If sequence arguments differ in number of arrays or contain arrays of
        unequal length; if wherefunc does not take exactly two parameters; if
        whereargs is not a 2-tuple or its sizes do not match left / right; or
        if wherefunc raises when evaluated on a small leading sample.

    Notes
    -----
    The return values satisfy the following assertions

    `assert (left[leftInds] == right[rightInds]).all()`
    `assert wherefunc(whereargs[0][leftInds], whereargs[1][rightInds]).all()`

    """
    from inspect import signature

    # Multi-array (multi-column) mode only when BOTH sides are sequences.
    # NOTE(review): if only one side is a Sequence, the `.size` access below
    # fails with AttributeError rather than a descriptive error.
    is_sequence = isinstance(left, Sequence) and isinstance(right, Sequence)

    # Reduce processing to codes to prevent groupby on entire Categorical
    if isinstance(left, Categorical) and isinstance(right, Categorical):
        # Put both sides on the same categories so their codes are comparable.
        l, r = Categorical.standardize_categories([left, right])
        left, right = l.codes, r.codes

    if is_sequence:
        if len(left) != len(right):
            raise ValueError("Left must have same num arrays as right")
        left_size, right_size = left[0].size, right[0].size
        if not all(lf.size == left_size for lf in left) or not all(
            rt.size == right_size for rt in right
        ):
            raise ValueError("Multi-array arguments must have equal-length arrays")
    else:
        left_size, right_size = left.size, right.size  # type: ignore

    # At most 5 leading elements, used to smoke-test wherefunc before the join.
    sample = np.min((left_size, right_size, 5))  # type: ignore
    if wherefunc is not None:
        if len(signature(wherefunc).parameters) != 2:
            raise ValueError("wherefunc must be a function that accepts exactly two arguments")
        if whereargs is None or len(whereargs) != 2:
            raise ValueError("whereargs must be a 2-tuple with left and right arg arrays")
        if is_sequence:
            if len(whereargs[0]) != len(whereargs[1]):
                raise ValueError("Left must have same num arrays as right")
            first_wa_size, second_wa_size = whereargs[0][0].size, whereargs[1][0].size
            if not all(wa.size == first_wa_size for wa in whereargs[0]) or not all(
                wa.size == second_wa_size for wa in whereargs[1]
            ):
                raise ValueError("Multi-array arguments must have equal-length arrays")
        else:
            first_wa_size, second_wa_size = whereargs[0].size, whereargs[1].size  # type: ignore
        if first_wa_size != left_size:
            raise ValueError("Left whereargs must be same size as left join values")
        if second_wa_size != right_size:
            raise ValueError("Right whereargs must be same size as right join values")
        # Fail early, with a clear error, if wherefunc cannot handle the inputs.
        try:
            _ = wherefunc(whereargs[0][:sample], whereargs[1][:sample])
        except Exception as e:
            raise ValueError("Error evaluating wherefunc") from e

    # Need dense 0-up right index, to filter out left not in right
    keep, (denseLeft, denseRight) = right_align(left, right)
    if keep.sum() == 0:
        # Intersection is empty
        return zeros(0, dtype=akint64), zeros(0, dtype=akint64)
    # NOTE(review): `keep` appears to be a boolean mask over left; this turns
    # it into the positions of the kept left elements -- confirm with right_align.
    keep = arange(keep.size)[keep]
    # GroupBy right
    byRight = GroupBy(denseRight)
    # Get segment boundaries (starts, ends) of right for each left item
    # (appending the total size makes rightSegs[g + 1] the end of group g).
    rightSegs = concatenate((byRight.segments, array([denseRight.size])))
    starts = rightSegs[denseLeft]
    ends = rightSegs[denseLeft + 1]
    # gen_ranges for gather of right items: `ranges` indexes the GroupBy-permuted
    # right side, and `fullSegs` marks where each kept left element's run begins.
    fullSegs, ranges = gen_ranges(starts, ends)
    # Evaluate where clause
    if wherefunc is None:
        filtRanges = ranges
        filtSegs = fullSegs
        keep12 = keep
    else:
        if whereargs is not None:
            if not is_sequence:
                # Gather right whereargs
                rightWhere = whereargs[1][byRight.permutation][ranges]
                # Expand left whereargs
                keep_where = whereargs[0][keep]
                # NOTE(review): a Categorical left wherearg is reduced to its
                # codes here, but the right wherearg stays a Categorical and no
                # category standardization is done, so wherefunc receives codes
                # vs. Categorical -- confirm this is intended.
                keep_where = keep_where.codes if isinstance(keep_where, Categorical) else keep_where
                leftWhere = broadcast(fullSegs, keep_where, ranges.size)
            else:
                # Gather right whereargs
                rightWhere = [wa[byRight.permutation][ranges] for wa in whereargs[1]]
                # Expand left whereargs
                keep_where = [wa[keep] for wa in whereargs[0]]
                leftWhere = [broadcast(fullSegs, kw, ranges.size) for kw in keep_where]
        # Evaluate wherefunc and filter ranges, recompute segments
        whereSatisfied = wherefunc(leftWhere, rightWhere)
        filtRanges = ranges[whereSatisfied]
        # Exclusive scan: count of surviving pairs before each position.
        scan = cumsum(whereSatisfied) - whereSatisfied
        filtSegsWithZeros = scan[fullSegs]
        # Surviving pairs per kept left element (last one closed by the total).
        filtSegSizes = concatenate(
            (
                filtSegsWithZeros[1:] - filtSegsWithZeros[:-1],
                array([whereSatisfied.sum() - filtSegsWithZeros[-1]]),
            )
        )
        # Drop left elements whose pairs were all filtered out.
        keep2 = filtSegSizes > 0
        filtSegs = filtSegsWithZeros[keep2]
        keep12 = keep[keep2]

    # Gather right inds and expand left inds
    rightInds = byRight.permutation[filtRanges]
    leftInds = broadcast(filtSegs, arange(left_size)[keep12], filtRanges.size)
    return leftInds, rightInds