Source code for glow.hail.functions

# Copyright 2019 The Glow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hail as hl
from hail import MatrixTable
from hail.expr.expressions.typed_expressions import StructExpression
from hail.expr.types import tarray, tbool, tcall, tfloat64, tint, tlocus, tset, tstr, tstruct
from pyspark.sql import Column, DataFrame
import pyspark.sql.functions as fx
from typing import List, NoReturn, Optional
from typeguard import check_argument_types, check_return_type

__all__ = ['from_matrix_table']


def _get_sample_ids(col_key: StructExpression, include_sample_ids: bool) -> Optional[List[str]]:
    assert check_argument_types()

    has_sample_ids = 's' in col_key and col_key.s.dtype == tstr
    if include_sample_ids and has_sample_ids:
        sample_ids = [f"'{s}'" for s in col_key.s.collect()]
    else:
        sample_ids = None

    assert check_return_type(sample_ids)
    return sample_ids


def _get_genotypes_col(entry: StructExpression, sample_ids: Optional[List[str]]) -> Column:
    assert check_argument_types()

    entry_aliases = {
        'DP': 'depth',
        'FT': 'filters',
        'GL': 'genotypeLikelihoods',
        'PL': 'phredLikelihoods',
        'GP': 'posteriorProbabilities',
        'GQ': 'conditionalQuality',
        'HQ': 'haplotypeQualities',
        'EC': 'expectedAlleleCounts',
        'MQ': 'mappingQuality',
        'AD': 'alleleDepths'
    }

    base_struct_args = []
    for entry_field in entry:
        if entry_field == 'GT' and entry.GT.dtype == tcall:
            # Flatten GT into calls and phased
            base_struct_args.append(
                "'calls', ifnull(e.GT.alleles, array(-1,-1)), 'phased', e.GT.phased")
        elif entry[entry_field].dtype == tcall:
            # Turn other call fields (eg. PGT) into a string
            base_struct_args.append(
                f"'{entry_field}', array_join(e.{entry_field}.alleles, if(e.{entry_field}.phased, '|', '/'))"
            )
        elif entry_field in entry_aliases:
            # Rename aliased genotype fields
            base_struct_args.append(f"'{entry_aliases[entry_field]}', e.{entry_field}")
        else:
            # Rename genotype fields
            base_struct_args.append(f"'{entry_field}', e.{entry_field}")

    if sample_ids is not None:
        sample_id_expr = f"array({' ,'.join(sample_ids)})"
        struct_expr = ' ,'.join(["'sampleId', s"] + base_struct_args)
        genotypes_col = fx.expr(
            f"zip_with({sample_id_expr}, entries, (s, e) -> named_struct({struct_expr}))")
    else:
        struct_expr = ' ,'.join(base_struct_args)
        genotypes_col = fx.expr(f"transform(entries, e -> named_struct({struct_expr}))")
    genotypes_col = genotypes_col.alias("genotypes")

    assert check_return_type(genotypes_col)
    return genotypes_col


def _get_base_cols(row: StructExpression) -> List[Column]:
    assert check_argument_types()

    contig_name_col = fx.col("`locus.contig`").alias("contigName")

    start_col = (fx.col("`locus.position`") - 1).cast("long").alias("start")

    end_col = start_col + fx.length(fx.element_at("alleles", 1))
    has_info = 'info' in row and isinstance(row.info.dtype, tstruct)
    if has_info and 'END' in row.info and row.info.END.dtype == tint:
        end_col = fx.coalesce(fx.col("`info.END`"), end_col)
    end_col = end_col.cast("long").alias("end")

    names_elems = []
    if 'varid' in row and row.varid.dtype == tstr:
        names_elems.append("varid")
    if 'rsid' in row and row.rsid.dtype == tstr:
        names_elems.append("rsid")
    names_col = fx.expr(
        f"nullif(filter(array({','.join(names_elems)}), n -> isnotnull(n)), array())").alias("names")

    reference_allele_col = fx.element_at("alleles", 1).alias("referenceAllele")

    alternate_alleles_col = fx.expr("slice(alleles, 2, size(alleles) - 1)").alias("alternateAlleles")

    base_cols = [
        contig_name_col, start_col, end_col, names_col, reference_allele_col, alternate_alleles_col
    ]
    assert check_return_type(base_cols)
    return base_cols


def _get_other_cols(row: StructExpression) -> List[Column]:
    assert check_argument_types()

    other_cols = []
    if 'cm_position' in row and row.cm_position.dtype == tfloat64:
        other_cols.append(fx.col("cm_position").alias("position"))
    if 'qual' in row and row.qual.dtype == tfloat64:
        # -10 qual means missing
        other_cols.append(fx.expr("if(qual = -10, null, qual)").alias("qual"))
    # [] filters means PASS, null filters means missing
    if 'filters' in row and row.filters.dtype == tset(tstr):
        other_cols.append(fx.expr("if(size(filters) = 0, array('PASS'), filters)").alias("filters"))
    # Rename info.* columns to INFO_*
    if 'info' in row and isinstance(row.info.dtype, tstruct):
        for f in row.info:
            other_cols.append(fx.col(f"`info.{f}`").alias(f"INFO_{f}"))

    assert check_return_type(other_cols)
    return other_cols


def _require_row_variant_w_struct_locus(mt: MatrixTable) -> NoReturn:
    """
    Similar to hail.methods.misc.require_row_key_variant_w_struct_locus, but not necessarily as keys
    """
    assert check_argument_types()

    if (not set(['locus', 'alleles']).issubset(set(mt.rows().row)) or
            not mt['alleles'].dtype == tarray(tstr) or
        (not isinstance(mt['locus'].dtype, tlocus) and
         mt['locus'].dtype != hl.dtype('struct{contig: str, position: int32}'))):
        raise ValueError("'hail.from_matrix_table' requires row to contain two fields 'locus'"
                         " (type 'locus<any>' or 'struct{{contig: str, position: int32}}') and "
                         "'alleles' (type 'array<str>')")


[docs]def from_matrix_table(mt: MatrixTable, include_sample_ids: bool = True) -> DataFrame: """ Converts a Hail MatrixTable to a Glow DataFrame. The variant fields are derived from the Hail MatrixTable row fields. The sample IDs are derived from the Hail MatrixTable column fields. All other genotype fields are derived from the Hail MatrixTable entry fields. Requires that the MatrixTable rows contain locus and alleles fields. Args: mt : The Hail MatrixTable to convert include_sample_ids : If true (default), include sample IDs in the Glow DataFrame. Sample names increase the size of each row, both in memory and on storage. Returns: Glow DataFrame converted from the MatrixTable. """ assert check_argument_types() # Ensure that dataset rows contain locus and alleles _require_row_variant_w_struct_locus(mt) glow_compatible_df = mt.localize_entries('entries').to_spark().select( *_get_base_cols(mt.rows().row), *_get_other_cols(mt.rows().row), _get_genotypes_col(mt.entry, _get_sample_ids(mt.col_key, include_sample_ids))) assert check_return_type(glow_compatible_df) return glow_compatible_df