Source code for RNApysoforms.shorten_gaps

import polars as pl
import plotly.graph_objects as go
from typing import List, Union
from RNApysoforms.to_intron import to_intron
from RNApysoforms.utils import check_df
from RNApysoforms.calculate_exon_number import calculate_exon_number


[docs] def shorten_gaps( annotation: pl.DataFrame, transcript_id_column: str = "transcript_id", target_gap_width: int = 100 ) -> pl.DataFrame: """ Shortens intron and transcript start gaps between exons in genomic annotations to enhance visualization. This function processes genomic annotations by shortening the widths of intron gaps and gaps at the start of transcripts to a specified target size, while preserving exon and CDS regions. The goal is to improve the clarity of transcript visualizations by reducing the visual space occupied by long intron regions and aligning transcripts for consistent rescaling, maintaining the relative structure of the transcripts. Parameters ---------- annotation : pl.DataFrame A Polars DataFrame containing genomic annotations, including exons and optionally CDS and intron data. Required columns include: - 'start': Start position of the feature. - 'end': End position of the feature. - 'type': Feature type, expected to include 'exon' and optionally 'intron' and 'CDS'. - 'strand': Strand information ('+' or '-'). - 'seqnames': Chromosome or sequence name. - transcript_id_column: Column used to group transcripts, typically 'transcript_id'. transcript_id_column : str, optional The column used to group transcripts, by default "transcript_id". This identifies individual transcripts within the annotation data. target_gap_width : int, optional The maximum width for intron gaps and transcript start gaps after shortening. Gaps wider than this will be reduced to this size. Default is 100. Returns ------- pl.DataFrame A Polars DataFrame with shortened intron and transcript start gaps and rescaled coordinates for exons, introns, and CDS regions. The DataFrame includes: - Original columns from the input DataFrame. - 'rescaled_start': The rescaled start position after shortening gaps. - 'rescaled_end': The rescaled end position after shortening gaps. Raises ------ TypeError If 'annotation' is not a Polars DataFrame. ValueError If required columns are missing in the input DataFrame. If exons are not from a single chromosome and strand when calculating gaps. If there are no common columns to join on between CDS and exons when processing CDS regions. Examples -------- Shorten intron and transcript start gaps in a genomic annotation DataFrame: >>> import polars as pl >>> from RNApysoforms import shorten_gaps >>> df = pl.DataFrame({ ... "transcript_id": ["tx1", "tx1", "tx1"], ... "start": [100, 200, 500], ... "end": [150, 250, 600], ... "type": ["exon", "exon", "exon"], ... "strand": ["+", "+", "+"], ... "seqnames": ["chr1", "chr1", "chr1"], ... "exon_number": [1, 2, 3] ... }) >>> shortened_df = shorten_gaps(df, transcript_id_column="transcript_id", target_gap_width=50) >>> print(shortened_df.head()) This will return a DataFrame where the intron and transcript start gaps have been shortened to a maximum width of 50, and includes rescaled coordinates for visualization. Notes ----- - The function ensures that exon and CDS regions maintain their original lengths, while intron gaps and transcript start gaps are shortened. - If intron entries are not present in the input DataFrame, the function generates them using the 'to_intron' function. - If 'exon_number' is not present in the input DataFrame, it will be automatically calculated. - The input DataFrame must contain the required columns listed above. - The function processes gaps at the start of transcripts to align transcripts for consistent rescaling. - After shortening gaps, the coordinates are rescaled to maintain the relative positions of features within and across transcripts. - The function returns the rescaled DataFrame with original columns plus 'rescaled_start' and 'rescaled_end'. """ # Check if annotation is a Polars DataFrame if not isinstance(annotation, pl.DataFrame): raise TypeError( f"Expected 'annotation' to be of type pl.DataFrame, got {type(annotation)}." "\nYou can convert a pandas DataFrame to Polars using: polars_df = pl.from_pandas(pandas_df)" ) # Validate the input DataFrame to ensure required columns are present check_df(annotation, ["start", "end", "type", "strand", "seqnames", transcript_id_column]) # Check if there are intron entries in the annotation data if "intron" in annotation["type"].unique().to_list(): check_df(annotation, ["start", "end", "type", "strand", "seqnames", transcript_id_column, "exon_number"]) # Separate intron data if present introns = annotation.filter(pl.col("type") == "intron") else: # Generate intron entries if they are not present annotation = to_intron(annotation=annotation, transcript_id_column=transcript_id_column) introns = annotation.filter(pl.col("type") == "intron") # Separate intron data check_df(annotation, ["start", "end", "type", "strand", "seqnames", transcript_id_column, "exon_number"]) # Check if there are CDS entries in the annotation data if "CDS" in annotation["type"].unique().to_list(): # Separate CDS data if present cds = annotation.filter(pl.col("type") == "CDS") else: cds = None # No CDS entries in the data # Separate exons from the annotation data exons = annotation.filter(pl.col("type") == "exon") # Ensure the 'type' column in exons and introns is set correctly exons = _get_type(exons, "exons") # Mark the type as 'exon' introns = _get_type(introns, "introns") # Mark the type as 'intron' # Identify gaps between exons within the same chromosome and strand gaps = _get_gaps(exons) # Map gaps to introns to identify which gaps correspond to which introns gap_map = _get_gap_map(introns, gaps) # Shorten gaps based on the target gap width introns_shortened = _get_shortened_gaps( introns, gaps, gap_map, transcript_id_column, target_gap_width ) # Handle gaps at the start of transcripts to align them tx_start_gaps = _get_tx_start_gaps(exons, transcript_id_column) # Gaps at the start of transcripts gap_map_tx_start = _get_gap_map(tx_start_gaps, gaps) tx_start_gaps_shortened = _get_shortened_gaps( tx_start_gaps, gaps, gap_map_tx_start, transcript_id_column, target_gap_width ) tx_start_gaps_shortened = tx_start_gaps_shortened.drop(['start', 'end', 'strand', 'seqnames']) # Rescale the coordinates of exons and introns after shortening the gaps rescaled_tx = _get_rescaled_txs( exons, introns_shortened, tx_start_gaps_shortened, transcript_id_column ) # Process CDS regions if available if isinstance(cds, pl.DataFrame): # Calculate differences between exons and CDS regions cds_diff = _get_cds_exon_difference(exons, cds, transcript_id_column) # Rescale CDS regions based on the rescaled exons rescaled_cds = _get_rescale_cds(cds_diff, rescaled_tx.filter(pl.col("type") == "exon"), transcript_id_column) ## Prepare data for concatenation final_columns = annotation.columns + ["rescaled_start", "rescaled_end"] rescaled_cds = rescaled_cds[final_columns] rescaled_tx = rescaled_tx[final_columns] # Combine the rescaled CDS data into the final DataFrame rescaled_tx = pl.concat([rescaled_tx, rescaled_cds]) # Sort the DataFrame by start and end positions rescaled_tx = rescaled_tx.sort(by=['start', 'end']) # Return transcripts in original order they were given original_order = annotation[transcript_id_column].unique(maintain_order=True).to_list() order_mapping = {transcript: index for index, transcript in enumerate(original_order)} rescaled_tx = (rescaled_tx .with_columns(pl.col(transcript_id_column).replace(order_mapping).alias("order")) .sort("order") .drop("order")) # Include original columns and rescaled coordinates in the final DataFrame final_columns = annotation.columns + ["rescaled_start", "rescaled_end"] rescaled_tx = rescaled_tx[final_columns].clone() return rescaled_tx # Return the rescaled transcript DataFrame
def _get_type(df: pl.DataFrame, df_type: str) -> pl.DataFrame: """ Ensures that the 'type' column in the DataFrame is correctly set to 'exon' or 'intron'. Parameters ---------- df : pl.DataFrame A Polars DataFrame containing genomic features. df_type : str The type to set in the 'type' column, either 'exons' or 'introns'. Returns ------- pl.DataFrame The input DataFrame with the 'type' column set to 'exon' or 'intron'. Raises ------ ValueError If 'df_type' is not 'exons' or 'introns'. Notes ----- - If the 'type' column does not exist in the input DataFrame, it is added with the specified 'df_type'. - If 'df_type' is 'introns', the function filters the DataFrame to include only intron entries. """ # Validate 'df_type' parameter if df_type not in ["exons", "introns"]: raise ValueError("df_type must be either 'exons' or 'introns'") # Add or set the 'type' column if 'type' not in df.schema: # If 'type' column is missing, add it with the appropriate value return df.with_columns( pl.lit('exon' if df_type == 'exons' else 'intron').alias('type') ) elif df_type == 'introns': # If 'df_type' is 'introns', ensure only intron entries are included df = df.filter(pl.col('type') == 'intron') return df def _get_gaps(exons: pl.DataFrame) -> pl.DataFrame: """ Identifies gaps between exons within the same chromosome and strand. Parameters ---------- exons : pl.DataFrame A Polars DataFrame containing exon information with 'seqnames', 'start', 'end', and 'strand'. Returns ------- pl.DataFrame A DataFrame with 'start' and 'end' positions of gaps between exons. Raises ------ ValueError If exons are not from a single chromosome and strand. Notes ----- - All exons must be from the same chromosome and strand to accurately identify gaps. - The function merges overlapping exons and computes the gaps between them. """ # Ensure all exons are from a single chromosome and strand seqnames_unique = exons["seqnames"].n_unique() strand_unique = exons["strand"].n_unique() if seqnames_unique != 1 or strand_unique != 1: raise ValueError("Exons must be from a single chromosome and strand") # Sort exons by start position exons_sorted = exons.sort('start') # Compute cumulative maximum of 'end' shifted by 1 to identify overlaps exons_with_cummax = exons_sorted.with_columns([ pl.col('end').cum_max().shift(1).fill_null(0).alias('cummax_end') ]) # Determine if a new group starts (i.e., no overlap with previous exons) exons_with_cummax = exons_with_cummax.with_columns([ (pl.col('start') > pl.col('cummax_end')).alias('is_new_group') ]) # Compute group_id as cumulative sum of 'is_new_group' exons_with_cummax = exons_with_cummax.with_columns([ pl.col('is_new_group').cast(pl.Int64).cum_sum().alias('group_id') ]) # Merge exons within each group to identify continuous blocks merged_exons = exons_with_cummax.group_by('group_id').agg([ pl.col('start').min().alias('start'), pl.col('end').max().alias('end') ]) # Sort merged exons by 'start' merged_exons = merged_exons.sort('start') # Compute 'prev_end' as the shifted 'end' to identify gaps merged_exons = merged_exons.with_columns([ pl.col('end').shift(1).alias('prev_end') ]) # Compute gap start and end positions merged_exons = merged_exons.with_columns([ (pl.col('prev_end') + 1).alias('gap_start'), (pl.col('start') - 1).alias('gap_end') ]) # Filter valid gaps where 'gap_start' is less than or equal to 'gap_end' gaps = merged_exons.filter(pl.col('gap_start') <= pl.col('gap_end')).select([ pl.col('gap_start').alias('start'), pl.col('gap_end').alias('end') ]) return gaps # Return the DataFrame containing gap positions def _get_tx_start_gaps(exons: pl.DataFrame, transcript_id_column: str) -> pl.DataFrame: """ Identifies gaps at the start of each transcript based on the first exon. Parameters ---------- exons : pl.DataFrame A Polars DataFrame containing exon information. transcript_id_column : str Column used to group transcripts (e.g., 'transcript_id'). Returns ------- pl.DataFrame A DataFrame containing gaps at the start of each transcript. Notes ----- - The function calculates the gap between the overall start of the first exon across all transcripts and the start of each individual transcript's first exon. - It assumes that all exons are on the same chromosome and strand. """ # Get the start of the first exon for each transcript tx_starts = exons.group_by(transcript_id_column).agg(pl.col('start').min()) # Get the overall start of the first exon across all transcripts overall_start = exons['start'].min() # Use the same chromosome and strand for all transcripts seqnames_value = exons['seqnames'][0] strand_value = exons['strand'][0] # Create DataFrame with gaps at the start of transcripts tx_start_gaps = tx_starts.with_columns([ pl.col('start').cast(pl.Int64).alias('end'), pl.lit(overall_start).cast(pl.Int64).alias('start'), pl.lit(seqnames_value).alias('seqnames'), pl.lit(strand_value).alias('strand'), ]) return tx_start_gaps # Return the DataFrame with transcript start gaps def _get_gap_map(df: pl.DataFrame, gaps: pl.DataFrame) -> dict: """ Maps gaps to the corresponding exons or introns based on their positions. Parameters ---------- df : pl.DataFrame A DataFrame containing exons or introns, with 'start' and 'end' positions. gaps : pl.DataFrame A DataFrame containing gaps between exons, with 'start' and 'end' positions. Returns ------- dict A dictionary containing mappings: - 'equal': DataFrame of gaps that exactly match the 'start' and 'end' of exons/introns. - 'pure_within': DataFrame of gaps that are fully within exons/introns but do not exactly match. Notes ----- - The function adds row indices to both df and gaps for mapping. - It first identifies exact matches, then finds gaps fully within exons/introns. """ # Add an index to each gap and exon/intron row gaps = gaps.with_row_index("gap_index") df = df.with_row_index("df_index") # Find gaps where the start and end positions exactly match those of df equal_hits = gaps.join(df, how="inner", left_on=["start", "end"], right_on=["start", "end"]).select([ pl.col("gap_index"), pl.col("df_index") ]) # Rename columns for clarity when performing the cross join gaps = gaps.rename({ "start": "gaps.start", "end": "gaps.end" }) df = df.rename({ "start": "df.start", "end": "df.end" }) # Find gaps that are fully contained within exons/introns within_hits = gaps.join(df, how="cross").filter( (pl.col("gaps.start") >= pl.col("df.start")) & (pl.col("gaps.end") <= pl.col("df.end")) ).select([pl.col("gap_index"), pl.col("df_index")]) # Remove within_hits that also appear in equal_hits pure_within_hits = within_hits.join(equal_hits, how="anti", on=["df_index", "gap_index"]) # Sort the equal_hits by gap and df index equal_hits = equal_hits.sort(["gap_index", "df_index"]) # Return the mappings return { 'equal': equal_hits, 'pure_within': pure_within_hits } def _get_shortened_gaps(df: pl.DataFrame, gaps: pl.DataFrame, gap_map: dict, transcript_id_column: str, target_gap_width: int) -> pl.DataFrame: """ Shortens the gaps between exons or introns based on a target gap width. Parameters ---------- df : pl.DataFrame A DataFrame containing exons or introns. gaps : pl.DataFrame A DataFrame containing gaps between exons. gap_map : dict A dictionary mapping gaps to their corresponding exons or introns. transcript_id_column : str Column used to group transcripts (e.g., 'transcript_id'). target_gap_width : int The maximum allowed width for the gaps. Returns ------- pl.DataFrame A DataFrame with shortened gaps and adjusted positions. Notes ----- - Gaps classified as 'equal' and exceeding the target width are shortened to match the target. - Gaps classified as 'pure_within' are adjusted based on the target width, ensuring they do not exceed the defined maximum. - The function updates the 'width' of each gap accordingly and removes unnecessary columns post-adjustment. """ # Calculate the width of exons/introns and initialize a 'shorten_type' column df = df.with_columns( (pl.col('end') - pl.col('start') + 1).alias('width'), # Calculate the width pl.lit('none').alias('shorten_type') # Initialize shorten_type column ) # Add an index column to the df DataFrame df = df.with_row_index(name="df_index") # Update 'shorten_type' for gaps that exactly match exons/introns if 'equal' in gap_map and 'df_index' in gap_map['equal'].columns: df = df.with_columns( pl.when(pl.col("df_index").is_in(gap_map["equal"]["df_index"].to_list())) .then(pl.lit("equal")) .otherwise(pl.col("shorten_type")) .alias("shorten_type") ) # Update 'shorten_type' for gaps fully within exons/introns if 'pure_within' in gap_map and 'df_index' in gap_map['pure_within'].columns: df = df.with_columns( pl.when(pl.col("df_index").is_in(gap_map['pure_within']['df_index'].to_list())) .then(pl.lit("pure_within")) .otherwise(pl.col("shorten_type")) .alias("shorten_type") ) # Shorten gaps that are of type 'equal' and have a width greater than the target_gap_width df = df.with_columns( pl.when((pl.col('shorten_type') == 'equal') & (pl.col('width') > target_gap_width)) .then(pl.lit(target_gap_width)) .otherwise(pl.col('width')) .alias('shortened_width') ) # Handle gaps that are 'pure_within' if 'pure_within' in gap_map and len(gap_map['pure_within']) > 0: overlapping_gap_indexes = gap_map['pure_within']['gap_index'] gaps = gaps.with_row_index(name="gap_index") if len(overlapping_gap_indexes) > 0: # Calculate the width of overlapping gaps overlapping_gaps = gaps.filter(pl.col("gap_index").is_in(overlapping_gap_indexes.to_list())) overlapping_gaps = overlapping_gaps.with_columns( (pl.col('end') - pl.col('start') + 1).alias('gap_width') ) # Shorten gap width if larger than target_gap_width overlapping_gaps = overlapping_gaps.with_columns( pl.when(pl.col('gap_width') > target_gap_width) .then(pl.lit(target_gap_width)) .otherwise(pl.col('gap_width')) .alias('shortened_gap_width') ) # Calculate the gap difference overlapping_gaps = overlapping_gaps.with_columns( (pl.col('gap_width') - pl.col('shortened_gap_width')).alias('shortened_gap_diff') ) # Map the gap differences back to df gap_diff_df = gap_map['pure_within'].join( overlapping_gaps.select('gap_index', 'shortened_gap_diff'), on='gap_index', how='left' ) # Aggregate gap differences by df indexes sum_gap_diff = gap_diff_df.group_by('df_index').agg( pl.sum('shortened_gap_diff').alias('sum_shortened_gap_diff') ) # Join the calculated gap differences with the df DataFrame df = df.join(sum_gap_diff, on='df_index', how='left') # Adjust the width based on gap differences df = df.with_columns( pl.when(pl.col('sum_shortened_gap_diff').is_null()) .then(pl.col('shortened_width')) .otherwise(pl.col('width') - pl.col('sum_shortened_gap_diff')) .alias('shortened_width') ) # Clean up unnecessary columns df = df.drop('sum_shortened_gap_diff') df = df.drop(['shorten_type', 'width', 'df_index']) df = df.rename({'shortened_width': 'width'}) return df # Return the DataFrame with shortened gaps def _get_rescaled_txs( exons: pl.DataFrame, introns_shortened: pl.DataFrame, tx_start_gaps_shortened: pl.DataFrame, transcript_id_column: str ) -> pl.DataFrame: """ Rescales transcript coordinates based on shortened gaps for exons and introns. Parameters ---------- exons : pl.DataFrame DataFrame containing exon information. introns_shortened : pl.DataFrame DataFrame containing intron information with shortened gaps. tx_start_gaps_shortened : pl.DataFrame DataFrame containing rescaled transcript start gaps. transcript_id_column : str Column used to group transcripts (e.g., 'transcript_id'). Returns ------- pl.DataFrame Rescaled transcript DataFrame with adjusted coordinates. Notes ----- - The function concatenates exons and shortened introns, sorts them, and calculates rescaled start and end positions. - It adjusts intron positions to prevent overlap with exons. - Transcript start gaps are incorporated to ensure accurate rescaling across different transcripts. """ # Clone exons to avoid altering the original DataFrame exons = exons.clone() # Define columns to keep for introns, including 'width' column_to_keep = exons.columns + ["width"] # Select and reorder columns for the shortened introns introns_shortened = introns_shortened.select(column_to_keep) # Add a new 'width' column to exons representing their lengths exons = exons.with_columns( (pl.col('end') - pl.col('start') + 1).alias('width') ) # Concatenate exons and shortened introns into a single DataFrame rescaled_tx = pl.concat([exons, introns_shortened], how='vertical') # Sort based on transcript_id, start, and end rescaled_tx = rescaled_tx.sort([transcript_id_column, 'start', 'end']) # Calculate cumulative sum for rescaled end positions rescaled_tx = rescaled_tx.with_columns( pl.col('width').cum_sum().over(transcript_id_column).alias('rescaled_end') ) # Compute the rescaled start positions based on the cumulative end positions rescaled_tx = rescaled_tx.with_columns( (pl.col('rescaled_end') - pl.col('width') + 1).alias('rescaled_start') ) # Join rescaled transcript start gaps to adjust start positions rescaled_tx = rescaled_tx.join( tx_start_gaps_shortened, on=transcript_id_column, how='left', suffix='_tx_start' ) # Adjust the rescaled start and end positions based on transcript start gaps rescaled_tx = rescaled_tx.with_columns([ (pl.col('rescaled_end') + pl.col('width_tx_start')).alias('rescaled_end'), (pl.col('rescaled_start') + pl.col('width_tx_start')).alias('rescaled_start') ]) # Drop 'width' column as it's no longer needed rescaled_tx = rescaled_tx.drop(['width']) # Reorder columns for consistency in the output columns = rescaled_tx.columns column_order = ['seqnames', 'start', 'end', "rescaled_start", "rescaled_end", 'strand'] + [ col for col in columns if col not in ['seqnames', 'start', 'end', "rescaled_start", "rescaled_end", 'strand'] ] rescaled_tx = rescaled_tx.select(column_order) return rescaled_tx # Return the rescaled transcript coordinates def _get_cds_exon_difference(gene_exons: pl.DataFrame, gene_cds_regions: pl.DataFrame, transcript_id_column: str) -> pl.DataFrame: """ Calculates the absolute differences between the start and end positions of exons and CDS regions. Parameters ---------- gene_exons : pl.DataFrame DataFrame containing exon regions. gene_cds_regions : pl.DataFrame DataFrame containing CDS (Coding DNA Sequence) regions. transcript_id_column : str The column name that identifies transcript groups within the DataFrame. Returns ------- pl.DataFrame DataFrame with the absolute differences between exon and CDS start/end positions. Raises ------ ValueError If the required columns 'exon_number' and transcript_id_column are missing from either DataFrame. Notes ----- - The function joins CDS and exon DataFrames on transcript_id_column and 'exon_number' to align corresponding regions. - It calculates the absolute differences between exon and CDS start and end positions to identify discrepancies. """ # Rename 'start' and 'end' columns in CDS regions for clarity cds_regions = gene_cds_regions.rename({'start': 'cds_start', 'end': 'cds_end'}) # Remove the 'type' column if it exists in CDS if 'type' in cds_regions.columns: cds_regions = cds_regions.drop('type') # Rename 'start' and 'end' columns in exon regions for clarity exons = gene_exons.rename({'start': 'exon_start', 'end': 'exon_end'}) # Remove the 'type' column if it exists in exons if 'type' in exons.columns: exons = exons.drop('type') ## Define required columns required_columns = [transcript_id_column, "exon_number"] # Identify common columns to join CDS and exons on (e.g., transcript_id) if not all(col in cds_regions.columns for col in required_columns) or not all(col in exons.columns for col in required_columns): raise ValueError("Missing necessary 'exon_number' and/or '" + transcript_id_column + "' columns needed to join CDS and exons.") # Perform left join between CDS and exon data on the common columns cds_exon_diff = cds_regions.join(exons[[transcript_id_column, "exon_number", "exon_start", "exon_end"]], on=required_columns, how='left') # Calculate absolute differences between exon and CDS start positions cds_exon_diff = cds_exon_diff.with_columns( (pl.col('exon_start') - pl.col('cds_start')).abs().alias('diff_start') ) # Calculate absolute differences between exon and CDS end positions cds_exon_diff = cds_exon_diff.with_columns( (pl.col('exon_end') - pl.col('cds_end')).abs().alias('diff_end') ) return cds_exon_diff # Return the DataFrame with differences def _get_rescale_cds(cds_exon_diff: pl.DataFrame, gene_rescaled_exons: pl.DataFrame, transcript_id_column: str) -> pl.DataFrame: """ Rescales CDS regions based on exon positions and the calculated differences between them. Parameters ---------- cds_exon_diff : pl.DataFrame DataFrame with differences between exon and CDS start/end positions. gene_rescaled_exons : pl.DataFrame DataFrame containing rescaled exon positions. transcript_id_column : str The column name that identifies transcript groups within the DataFrame. Returns ------- pl.DataFrame Rescaled CDS positions based on exon positions. Raises ------ ValueError If the required columns 'exon_number' and transcript_id_column are missing from either DataFrame. Notes ----- - The function joins CDS differences and rescaled exons on transcript_id_column and 'exon_number'. - It adjusts CDS start and end positions based on the rescaled exon positions and the previously calculated differences. - It ensures that CDS regions are accurately positioned relative to exons after rescaling. """ # Assign a 'type' column with the value "CDS" and drop unnecessary columns columns_to_drop = ['exon_start', 'exon_end'] cds_prepared = ( cds_exon_diff .with_columns(pl.lit("CDS").alias("type")) .drop([col for col in columns_to_drop if col in cds_exon_diff.columns]) ) # Rename columns in rescaled exons for consistency exons_prepared = gene_rescaled_exons.rename({'rescaled_start': 'exon_start', 'rescaled_end': 'exon_end'}) exons_prepared = exons_prepared.drop(["start", "end"]) # Drop 'type' column if present if 'type' in exons_prepared.columns: exons_prepared = exons_prepared.drop('type') ## Define required columns required_columns = [transcript_id_column, "exon_number"] # Identify common columns to join CDS and exons on (e.g., transcript_id) if not all(col in cds_prepared.columns for col in required_columns) or not all(col in exons_prepared.columns for col in required_columns): raise ValueError("Missing necessary 'exon_number' and '" + transcript_id_column + "' columns needed to join CDS and exons.") # Perform left join on common columns gene_rescaled_cds = cds_prepared.join(exons_prepared[[transcript_id_column, "exon_number", "exon_start", "exon_end"]], on=required_columns, how='left') # Adjust start and end positions of CDS based on exon positions gene_rescaled_cds = gene_rescaled_cds.with_columns([ (pl.col('exon_start') + pl.col('diff_start')).alias('rescaled_start'), (pl.col('exon_end') - pl.col('diff_end')).alias('rescaled_end') ]) # Drop unnecessary columns used for the difference calculations gene_rescaled_cds = gene_rescaled_cds.drop(['exon_start', 'exon_end', 'diff_start', 'diff_end']) # Rename CDS start and end to 'start' and 'end' gene_rescaled_cds = gene_rescaled_cds.rename({ "cds_start": "start", "cds_end": "end" }) return gene_rescaled_cds # Return the rescaled CDS DataFrame