import textwrap from typing import List, Set from pandas._libs import NaT, lib from pandas.errors import InvalidIndexError import pandas.core.common as com from pandas.core.indexes.base import ( Index, _new_Index, ensure_index, ensure_index_from_sequences, ) from pandas.core.indexes.category import CategoricalIndex from pandas.core.indexes.datetimes import DatetimeIndex from pandas.core.indexes.interval import IntervalIndex from pandas.core.indexes.multi import MultiIndex from pandas.core.indexes.numeric import ( Float64Index, Int64Index, NumericIndex, UInt64Index, ) from pandas.core.indexes.period import PeriodIndex from pandas.core.indexes.range import RangeIndex from pandas.core.indexes.timedeltas import TimedeltaIndex _sort_msg = textwrap.dedent( """\ Sorting because non-concatenation axis is not aligned. A future version of pandas will change to not sort by default. To accept the future behavior, pass 'sort=False'. To retain the current behavior and silence the warning, pass 'sort=True'. """ ) __all__ = [ "Index", "MultiIndex", "NumericIndex", "Float64Index", "Int64Index", "CategoricalIndex", "IntervalIndex", "RangeIndex", "UInt64Index", "InvalidIndexError", "TimedeltaIndex", "PeriodIndex", "DatetimeIndex", "_new_Index", "NaT", "ensure_index", "ensure_index_from_sequences", "get_objs_combined_axis", "union_indexes", "get_consensus_names", "all_indexes_same", ] def get_objs_combined_axis( objs, intersect: bool = False, axis=0, sort: bool = True, copy: bool = False ) -> Index: """ Extract combined index: return intersection or union (depending on the value of "intersect") of indexes on given axis, or None if all objects lack indexes (e.g. they are numpy arrays). Parameters ---------- objs : list Series or DataFrame objects, may be mix of the two. intersect : bool, default False If True, calculate the intersection between indexes. Otherwise, calculate the union. axis : {0 or 'index', 1 or 'outer'}, default 0 The axis to extract indexes from. sort : bool, default True Whether the result index should come out sorted or not. copy : bool, default False If True, return a copy of the combined index. Returns ------- Index """ obs_idxes = [obj._get_axis(axis) for obj in objs] return _get_combined_index(obs_idxes, intersect=intersect, sort=sort, copy=copy) def _get_distinct_objs(objs: List[Index]) -> List[Index]: """ Return a list with distinct elements of "objs" (different ids). Preserves order. """ ids: Set[int] = set() res = [] for obj in objs: if id(obj) not in ids: ids.add(id(obj)) res.append(obj) return res def _get_combined_index( indexes: List[Index], intersect: bool = False, sort: bool = False, copy: bool = False, ) -> Index: """ Return the union or intersection of indexes. Parameters ---------- indexes : list of Index or list objects When intersect=True, do not accept list of lists. intersect : bool, default False If True, calculate the intersection between indexes. Otherwise, calculate the union. sort : bool, default False Whether the result index should come out sorted or not. copy : bool, default False If True, return a copy of the combined index. Returns ------- Index """ # TODO: handle index names! indexes = _get_distinct_objs(indexes) if len(indexes) == 0: index = Index([]) elif len(indexes) == 1: index = indexes[0] elif intersect: index = indexes[0] for other in indexes[1:]: index = index.intersection(other) else: index = union_indexes(indexes, sort=sort) index = ensure_index(index) if sort: try: index = index.sort_values() except TypeError: pass # GH 29879 if copy: index = index.copy() return index def union_indexes(indexes, sort=True) -> Index: """ Return the union of indexes. The behavior of sort and names is not consistent. Parameters ---------- indexes : list of Index or list objects sort : bool, default True Whether the result index should come out sorted or not. Returns ------- Index """ if len(indexes) == 0: raise AssertionError("Must have at least 1 Index to union") if len(indexes) == 1: result = indexes[0] if isinstance(result, list): result = Index(sorted(result)) return result indexes, kind = _sanitize_and_check(indexes) def _unique_indices(inds) -> Index: """ Convert indexes to lists and concatenate them, removing duplicates. The final dtype is inferred. Parameters ---------- inds : list of Index or list objects Returns ------- Index """ def conv(i): if isinstance(i, Index): i = i.tolist() return i return Index(lib.fast_unique_multiple_list([conv(i) for i in inds], sort=sort)) if kind == "special": result = indexes[0] if hasattr(result, "union_many"): # DatetimeIndex return result.union_many(indexes[1:]) else: for other in indexes[1:]: result = result.union(other) return result elif kind == "array": index = indexes[0] for other in indexes[1:]: if not index.equals(other): return _unique_indices(indexes) name = get_consensus_names(indexes)[0] if name != index.name: index = index._shallow_copy(name=name) return index else: # kind='list' return _unique_indices(indexes) def _sanitize_and_check(indexes): """ Verify the type of indexes and convert lists to Index. Cases: - [list, list, ...]: Return ([list, list, ...], 'list') - [list, Index, ...]: Return _sanitize_and_check([Index, Index, ...]) Lists are sorted and converted to Index. - [Index, Index, ...]: Return ([Index, Index, ...], TYPE) TYPE = 'special' if at least one special type, 'array' otherwise. Parameters ---------- indexes : list of Index or list objects Returns ------- sanitized_indexes : list of Index or list objects type : {'list', 'array', 'special'} """ kinds = list({type(index) for index in indexes}) if list in kinds: if len(kinds) > 1: indexes = [ Index(list(x)) if not isinstance(x, Index) else x for x in indexes ] kinds.remove(list) else: return indexes, "list" if len(kinds) > 1 or Index not in kinds: return indexes, "special" else: return indexes, "array" def get_consensus_names(indexes): """ Give a consensus 'names' to indexes. If there's exactly one non-empty 'names', return this, otherwise, return empty. Parameters ---------- indexes : list of Index objects Returns ------- list A list representing the consensus 'names' found. """ # find the non-none names, need to tupleify to make # the set hashable, then reverse on return consensus_names = {tuple(i.names) for i in indexes if com.any_not_none(*i.names)} if len(consensus_names) == 1: return list(list(consensus_names)[0]) return [None] * indexes[0].nlevels def all_indexes_same(indexes): """ Determine if all indexes contain the same elements. Parameters ---------- indexes : iterable of Index objects Returns ------- bool True if all indexes contain the same elements, False otherwise. """ itr = iter(indexes) first = next(itr) for index in itr: if not first.equals(index): return False return True