hoodini.utils.validation
Input validation and ID parsing utilities.
1"""Input validation and ID parsing utilities.""" 2 3from __future__ import annotations 4 5import csv 6import re 7from pathlib import Path 8 9import polars as pl 10import rich_click as click 11 12from hoodini.utils.id_parsing import categorize_id 13from hoodini.utils.logging_utils import warn 14 15 16def validate_input_file(ctx, param, value): 17 """Validate the input file based on its type, either single-column or TSV format.""" 18 if value is None: 19 return None 20 21 if not Path(value).is_file(): 22 # Allow literal IDs/FASTA strings; runner will handle them. 23 return value 24 25 try: 26 with open(value) as file: 27 if param.name == "input_path": 28 lines = [line.strip() for line in file.readlines() if line.strip()] 29 if len(lines) <= 1: 30 raise click.BadParameter("File must contain multiple lines.") 31 32 for line in lines: 33 if "," in line or "\t" in line: 34 raise click.BadParameter( 35 "File must be a single-column text file without delimiters like commas or tabs." 36 ) 37 38 special_char_pattern = re.compile(r"[^A-Za-z0-9\s._]+") 39 for i, line in enumerate(lines): 40 match = special_char_pattern.search(line) 41 if match: 42 raise click.BadParameter( 43 f"Invalid character '{match.group()}' found in line {i+1}: \"{line}\"" 44 ) 45 46 elif param.name == "inputsheet": 47 first_line = file.readline() 48 delimiter = "\t" if "\t" in first_line else None 49 if delimiter is None: 50 raise click.BadParameter("The file is not in TSV format.") 51 file.seek(0) 52 reader = csv.DictReader(file, delimiter=delimiter) 53 required_columns = [ 54 "nucleotide_id", 55 "protein_id", 56 "gff_path", 57 "fna_path", 58 "faa_path", 59 ] 60 if not all(col in reader.fieldnames for col in required_columns): 61 raise click.BadParameter( 62 "TSV file must contain columns: nucleotide_id, protein_id, gff_path, fna_path, faa_path" 63 ) 64 65 found_valid_row = False 66 for row in reader: 67 if row["nucleotide_id"].strip() or row["protein_id"].strip(): 68 found_valid_row = True 69 break 70 if not found_valid_row: 71 raise click.BadParameter( 72 "The TSV file must contain at least one valid row with required data." 73 ) 74 75 except Exception as e: 76 raise click.BadParameter(f"Error reading file: {e}") 77 78 return value 79 80 81def validate_domains(ctx, param, value): 82 """Validate domain database names against MetaCerberus availability.""" 83 if not value or (isinstance(value, str) and value.strip() == ""): 84 return None 85 86 db_names = [d.strip().lower() for d in value.split(",") if d.strip()] 87 if not db_names: 88 raise click.BadParameter("Domain parameter must contain at least one database name.") 89 90 for db in db_names: 91 if not re.match(r"^[a-zA-Z0-9_-]+$", db): 92 raise click.BadParameter( 93 f"Invalid database name '{db}'. Only letters, numbers, underscores, and hyphens are allowed." 
94 ) 95 96 try: 97 from hoodini.download.metacerberus import check_downloaded, get_db_groups, list_db_files 98 99 files_list = list_db_files() 100 groups = get_db_groups(files_list) 101 status = check_downloaded(groups) 102 103 valid_dbs = [] 104 invalid_dbs = [] 105 missing_files_dbs = [] 106 107 for db in db_names: 108 if db not in groups: 109 invalid_dbs.append(db) 110 continue 111 112 file_statuses = status.get(db, []) 113 hmm_present = any( 114 present for f, present in file_statuses if f["name"].endswith(".hmm.gz") 115 ) 116 tsv_present = any(present for f, present in file_statuses if f["name"].endswith(".tsv")) 117 118 has_hmm_file = any(f["name"].endswith(".hmm.gz") for f, _ in file_statuses) 119 has_tsv_file = any(f["name"].endswith(".tsv") for f, _ in file_statuses) 120 121 if has_tsv_file and not tsv_present: 122 missing_files_dbs.append(f"{db} (missing TSV)") 123 elif has_hmm_file and not hmm_present: 124 missing_files_dbs.append(f"{db} (missing HMM)") 125 elif not has_hmm_file and not has_tsv_file: 126 invalid_dbs.append(db) 127 elif (has_hmm_file and hmm_present and has_tsv_file and tsv_present) or ( 128 not has_hmm_file and has_tsv_file and tsv_present 129 ): 130 valid_dbs.append(db) 131 else: 132 missing_files_dbs.append(db) 133 134 if invalid_dbs: 135 raise click.BadParameter( 136 f"Unknown MetaCerberus databases: {', '.join(invalid_dbs)}. Run 'hoodini download metacerberus' to see available databases." 137 ) 138 139 if missing_files_dbs: 140 raise click.BadParameter( 141 f"MetaCerberus databases not downloaded: {', '.join(missing_files_dbs)}. Run 'hoodini download metacerberus {','.join(db_names)}' to download them." 142 ) 143 144 if not valid_dbs: 145 raise click.BadParameter("No valid MetaCerberus databases found.") 146 147 return valid_dbs 148 149 except ImportError as e: 150 warn(f"Could not validate MetaCerberus databases: {e}") 151 return value 152 except Exception as e: 153 warn(f"Could not validate MetaCerberus databases: {e}") 154 return value 155 156 157def switch_assembly_prefix(asm_id): 158 if not isinstance(asm_id, str): 159 return asm_id 160 if asm_id.startswith("GCA_"): 161 return "GCF_" + asm_id[4:] 162 if asm_id.startswith("GCF_"): 163 return "GCA_" + asm_id[4:] 164 return asm_id 165 166 167def is_refseq_nuccore(nuc_id): 168 """Return True if the nuccore accession is a RefSeq accession else False.""" 169 refseq_prefixes = ("NC_", "NZ_", "NM_", "NR_", "XM_", "XR_", "AP_", "YP_", "XP_", "WP_") 170 return isinstance(nuc_id, str) and nuc_id.startswith(refseq_prefixes) 171 172 173def read_input_sheet(filename): 174 df = pl.read_csv(filename, separator="\t", dtype=str) 175 df = df.with_row_count("og_index").with_columns(pl.col("og_index").cast(pl.Utf8)) 176 177 expected_columns = [ 178 "protein_id", 179 "nucleotide_id", 180 "uniprot_id", 181 "gff_path", 182 "faa_path", 183 "fna_path", 184 "gbf_path", 185 "taxid", 186 "assembly_id", 187 "failed", 188 "input_type", 189 "premade", 190 ] 191 for col in expected_columns: 192 if col not in df.columns: 193 df = df.with_columns(pl.lit(None).alias(col)) 194 return df 195 196 197def read_input_list(filename): 198 input_list = Path(filename).read_text().splitlines() 199 data = [] 200 for index, id_ in enumerate(input_list): 201 if not id_ or id_.strip() == "": 202 continue 203 204 category = categorize_id(id_) 205 record = { 206 "og_index": index, 207 "protein_id": None, 208 "nucleotide_id": None, 209 "uniprot_id": None, 210 "failed": None, 211 "failed_reason": None, 212 "gff_path": None, 213 "faa_path": None, 214 
"fna_path": None, 215 "strand": None, 216 "start": None, 217 "end": None, 218 "gbf_path": None, 219 "taxid": None, 220 "assembly_id": None, 221 "input_type": None, 222 "premade": None, 223 } 224 225 if category["type"] == "protein": 226 record["protein_id"] = category["id"] 227 record["input_type"] = "protein" 228 elif category["type"] == "nucleotide": 229 record["nucleotide_id"] = category["id"] 230 record["protein_id"] = category.get("protein_id") 231 record["input_type"] = "nucleotide" 232 elif category["type"] == "uniprot": 233 record["uniprot_id"] = category["id"] 234 record["input_type"] = "protein" 235 elif category["type"] == "unmatched": 236 record["failed"] = True 237 record["failed_reason"] = "not valid ID" 238 else: 239 if ":" in id_: 240 category = categorize_id(id_.split(":")[0]) 241 if category["type"] == "nucleotide": 242 if ( 243 "-" in id_.split(":")[1] 244 and id_.split(":")[1].split("-")[0].isdigit() 245 and id_.split(":")[1].split("-")[1].isdigit() 246 ): 247 record["nucleotide_id"] = id_.split(":")[0] 248 record["start"] = id_.split(":")[1].split("-")[0] 249 record["end"] = id_.split(":")[1].split("-")[1] 250 if record["start"] > record["end"]: 251 record["start"], record["end"] = record["end"], record["start"] 252 record["strand"] = "-" 253 else: 254 record["strand"] = "+" 255 record["input_type"] = "nucleotide" 256 else: 257 record["failed"] = True 258 record["failed_reason"] = "not valid ID" 259 record["nucleotide_id"] = category["id"] 260 record["input_type"] = "nucleotide" 261 262 if record["protein_id"] is None and category.get("id"): 263 record["protein_id"] = category["id"] 264 record["input_type"] = record.get("input_type") or "protein" 265 data.append(record) 266 df = pl.DataFrame(data, infer_schema_length=len(data)) 267 return df 268 269 270def uniprot2ncbi(df: pl.DataFrame) -> pl.DataFrame: 271 """ 272 Map UniProt accessions to NCBI protein IDs using UniProtMapper. 273 274 - Only attempts mapping for rows with `uniprot_id` present, `protein_id` null/empty, 275 and no local files (`gff_path`/`faa_path` missing). 276 - If mapping returns multiple hits, merged `To` values are applied directly. 277 - On failure, sets the `failed` column with a descriptive message. 
278 """ 279 required_cols = {"uniprot_id", "protein_id", "gff_path", "faa_path"} 280 if not required_cols.issubset(set(df.columns)): 281 return df 282 283 # Cast protein_id to Utf8 to handle Null-type columns safely 284 protein_id_str = df["protein_id"].cast(pl.Utf8, strict=False) 285 mask = ( 286 df["uniprot_id"].is_not_null() 287 & (protein_id_str.is_null() | (protein_id_str == "")) 288 & df["gff_path"].is_null() 289 & df["faa_path"].is_null() 290 ) 291 292 if mask.sum() == 0: 293 return df 294 295 from UniProtMapper import ProtMapper 296 297 df_pd = df.to_pandas() 298 to_map = df_pd.loc[mask.to_pandas(), "uniprot_id"].dropna().unique().tolist() 299 if not to_map: 300 return df 301 302 mapper = ProtMapper() 303 try: 304 mapped_df, failed_ids = mapper.get( 305 ids=to_map, from_db="UniProtKB_AC-ID", to_db="EMBL-GenBank-DDBJ_CDS" 306 ) 307 except Exception: 308 failed_ids = to_map 309 mapped_df = None 310 311 if mapped_df is not None and not mapped_df.empty: 312 df_pd = df_pd.merge( 313 mapped_df[["From", "To"]], left_on="uniprot_id", right_on="From", how="left" 314 ) 315 df_pd["protein_id"] = df_pd["protein_id"].fillna(df_pd["To"]) 316 df_pd = df_pd.drop(columns=["From", "To"]) 317 318 if failed_ids: 319 failed_mask = df_pd["uniprot_id"].isin(failed_ids) 320 df_pd.loc[failed_mask, "failed"] = True 321 df_pd.loc[failed_mask, "failed_reason"] = "No associated NCBI found for the UniProt entry." 322 323 return pl.from_pandas(df_pd)
def validate_input_file(ctx, param, value):
Validate the input file based on its type, either single-column or TSV format.
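The function is written as a click callback, so the second argument must expose a `name` attribute matching the option being validated. A minimal sketch of exercising it outside a CLI; the `SimpleNamespace` stand-in, the file name, and the accessions are illustrative, not part of the module:

```python
from pathlib import Path
from types import SimpleNamespace

from hoodini.utils.validation import validate_input_file

# Stand-in for the click.Parameter object the callback normally receives.
fake_param = SimpleNamespace(name="input_path")

ids_file = Path("ids.txt")
ids_file.write_text("WP_012345678.1\nNC_000913.3\n")

# Returns the path unchanged when the single-column checks pass; raises
# click.BadParameter on delimiters, single-line files, or unexpected characters.
print(validate_input_file(None, fake_param, str(ids_file)))
```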
def validate_domains(ctx, param, value):
Validate domain database names against MetaCerberus availability.
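In the CLI this function is meant to be attached as a `callback` on the option that takes a comma-separated database list. A sketch of that wiring, assuming an option named `--domains` and a command named `annotate` (both illustrative):

```python
import rich_click as click

from hoodini.utils.validation import validate_domains

@click.command()
@click.option(
    "--domains",
    default=None,
    callback=validate_domains,
    help="Comma-separated MetaCerberus database names.",
)
def annotate(domains):
    # After validation, `domains` is None or a list of database names
    # that are known and already downloaded.
    click.echo(domains)

if __name__ == "__main__":
    annotate()
```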
def switch_assembly_prefix(asm_id):
Swap a GenBank (GCA_) assembly accession to its RefSeq (GCF_) form and vice versa.
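The swap is symmetric and leaves anything else untouched; the accessions below are illustrative:

```python
from hoodini.utils.validation import switch_assembly_prefix

assert switch_assembly_prefix("GCA_000005845.2") == "GCF_000005845.2"
assert switch_assembly_prefix("GCF_000005845.2") == "GCA_000005845.2"
# Non-assembly identifiers (and non-strings) are returned unchanged.
assert switch_assembly_prefix("NC_000913.3") == "NC_000913.3"
```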
def is_refseq_nuccore(nuc_id):
Return True if the nuccore accession is a RefSeq accession else False.
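A quick check of the prefix matching; the accessions are illustrative:

```python
from hoodini.utils.validation import is_refseq_nuccore

print(is_refseq_nuccore("NC_000913.3"))  # True  (RefSeq prefix NC_)
print(is_refseq_nuccore("CP009072.1"))   # False (GenBank accession)
print(is_refseq_nuccore(None))           # False (non-string input)
```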
def read_input_sheet(filename):
Read a tab-separated input sheet into a string-typed DataFrame with all expected columns.
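A minimal sketch of the expected sheet layout; only the identifier columns are filled in, and the file name and accessions are illustrative:

```python
from pathlib import Path

from hoodini.utils.validation import read_input_sheet

sheet = Path("inputsheet.tsv")
sheet.write_text(
    "nucleotide_id\tprotein_id\tgff_path\tfna_path\tfaa_path\n"
    "NC_000913.3\tWP_012345678.1\t\t\t\n"
)

df = read_input_sheet(sheet)
# Expected columns that are absent from the sheet (taxid, assembly_id, ...) are added as nulls,
# and an og_index column preserves the original row order.
print(df.columns)
```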
def read_input_list(filename):
Read a single-column ID list and build one record per ID, categorized by ID type.
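A sketch of turning a plain ID list into the working DataFrame; the IDs are illustrative and how each line is classified depends on `categorize_id`:

```python
from pathlib import Path

from hoodini.utils.validation import read_input_list

ids_file = Path("ids.txt")
ids_file.write_text("WP_012345678.1\nNC_000913.3\nP0A7G6\n")

df = read_input_list(ids_file)
# One row per non-empty line; each ID lands in the column matching its detected type,
# unrecognized lines are marked failed, and range inputs like 'ACC:100-500' also
# fill start/end/strand.
print(df.select(["og_index", "protein_id", "nucleotide_id", "uniprot_id", "input_type", "failed"]))
```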
def uniprot2ncbi(df: pl.DataFrame) -> pl.DataFrame:
Map UniProt accessions to NCBI protein IDs using UniProtMapper.
- Only attempts mapping for rows with `uniprot_id` present, `protein_id` null/empty, and no local files (`gff_path`/`faa_path` missing).
- If mapping returns multiple hits, merged `To` values are applied directly.
- On failure, sets the `failed` column with a descriptive message.
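A minimal sketch of the expected input frame. The UniProt accession is illustrative, and the call performs a live ID-mapping request through `UniProtMapper`, so it needs that package and network access:

```python
import polars as pl

from hoodini.utils.validation import uniprot2ncbi

df = pl.DataFrame(
    {
        "uniprot_id": ["P0A7G6", None],
        "protein_id": [None, "WP_012345678.1"],
        "gff_path": [None, None],
        "faa_path": [None, None],
    }
)

# Only the first row qualifies (uniprot_id set, protein_id empty, no local files);
# the second row already has an NCBI protein ID and is left untouched.
mapped = uniprot2ncbi(df)
print(mapped.select(["uniprot_id", "protein_id"]))
```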