hoodini.utils.validation

Input validation and ID parsing utilities.

  1"""Input validation and ID parsing utilities."""
  2
  3from __future__ import annotations
  4
  5import csv
  6import re
  7from pathlib import Path
  8
  9import polars as pl
 10import rich_click as click
 11
 12from hoodini.utils.id_parsing import categorize_id
 13from hoodini.utils.logging_utils import warn
 14
 15
 16def validate_input_file(ctx, param, value):
 17    """Validate the input file based on its type, either single-column or TSV format."""
 18    if value is None:
 19        return None
 20
 21    if not Path(value).is_file():
 22        # Allow literal IDs/FASTA strings; runner will handle them.
 23        return value
 24
 25    try:
 26        with open(value) as file:
 27            if param.name == "input_path":
 28                lines = [line.strip() for line in file.readlines() if line.strip()]
 29                if len(lines) <= 1:
 30                    raise click.BadParameter("File must contain multiple lines.")
 31
 32                for line in lines:
 33                    if "," in line or "\t" in line:
 34                        raise click.BadParameter(
 35                            "File must be a single-column text file without delimiters like commas or tabs."
 36                        )
 37
 38                special_char_pattern = re.compile(r"[^A-Za-z0-9\s._]+")
 39                for i, line in enumerate(lines):
 40                    match = special_char_pattern.search(line)
 41                    if match:
 42                        raise click.BadParameter(
 43                            f"Invalid character '{match.group()}' found in line {i+1}: \"{line}\""
 44                        )
 45
 46            elif param.name == "inputsheet":
 47                first_line = file.readline()
 48                delimiter = "\t" if "\t" in first_line else None
 49                if delimiter is None:
 50                    raise click.BadParameter("The file is not in TSV format.")
 51                file.seek(0)
 52                reader = csv.DictReader(file, delimiter=delimiter)
 53                required_columns = [
 54                    "nucleotide_id",
 55                    "protein_id",
 56                    "gff_path",
 57                    "fna_path",
 58                    "faa_path",
 59                ]
 60                if not all(col in reader.fieldnames for col in required_columns):
 61                    raise click.BadParameter(
 62                        "TSV file must contain columns: nucleotide_id, protein_id, gff_path, fna_path, faa_path"
 63                    )
 64
 65                found_valid_row = False
 66                for row in reader:
 67                    if row["nucleotide_id"].strip() or row["protein_id"].strip():
 68                        found_valid_row = True
 69                        break
 70                if not found_valid_row:
 71                    raise click.BadParameter(
 72                        "The TSV file must contain at least one valid row with required data."
 73                    )
 74
 75    except Exception as e:
 76        raise click.BadParameter(f"Error reading file: {e}")
 77
 78    return value
 79
 80
 81def validate_domains(ctx, param, value):
 82    """Validate domain database names against MetaCerberus availability."""
 83    if not value or (isinstance(value, str) and value.strip() == ""):
 84        return None
 85
 86    db_names = [d.strip().lower() for d in value.split(",") if d.strip()]
 87    if not db_names:
 88        raise click.BadParameter("Domain parameter must contain at least one database name.")
 89
 90    for db in db_names:
 91        if not re.match(r"^[a-zA-Z0-9_-]+$", db):
 92            raise click.BadParameter(
 93                f"Invalid database name '{db}'. Only letters, numbers, underscores, and hyphens are allowed."
 94            )
 95
 96    try:
 97        from hoodini.download.metacerberus import check_downloaded, get_db_groups, list_db_files
 98
 99        files_list = list_db_files()
100        groups = get_db_groups(files_list)
101        status = check_downloaded(groups)
102
103        valid_dbs = []
104        invalid_dbs = []
105        missing_files_dbs = []
106
107        for db in db_names:
108            if db not in groups:
109                invalid_dbs.append(db)
110                continue
111
112            file_statuses = status.get(db, [])
113            hmm_present = any(
114                present for f, present in file_statuses if f["name"].endswith(".hmm.gz")
115            )
116            tsv_present = any(present for f, present in file_statuses if f["name"].endswith(".tsv"))
117
118            has_hmm_file = any(f["name"].endswith(".hmm.gz") for f, _ in file_statuses)
119            has_tsv_file = any(f["name"].endswith(".tsv") for f, _ in file_statuses)
120
121            if has_tsv_file and not tsv_present:
122                missing_files_dbs.append(f"{db} (missing TSV)")
123            elif has_hmm_file and not hmm_present:
124                missing_files_dbs.append(f"{db} (missing HMM)")
125            elif not has_hmm_file and not has_tsv_file:
126                invalid_dbs.append(db)
127            elif (has_hmm_file and hmm_present and has_tsv_file and tsv_present) or (
128                not has_hmm_file and has_tsv_file and tsv_present
129            ):
130                valid_dbs.append(db)
131            else:
132                missing_files_dbs.append(db)
133
134        if invalid_dbs:
135            raise click.BadParameter(
136                f"Unknown MetaCerberus databases: {', '.join(invalid_dbs)}. Run 'hoodini download metacerberus' to see available databases."
137            )
138
139        if missing_files_dbs:
140            raise click.BadParameter(
141                f"MetaCerberus databases not downloaded: {', '.join(missing_files_dbs)}. Run 'hoodini download metacerberus {','.join(db_names)}' to download them."
142            )
143
144        if not valid_dbs:
145            raise click.BadParameter("No valid MetaCerberus databases found.")
146
147        return valid_dbs
148
149    except ImportError as e:
150        warn(f"Could not validate MetaCerberus databases: {e}")
151        return value
152    except Exception as e:
153        warn(f"Could not validate MetaCerberus databases: {e}")
154        return value
155
156
def switch_assembly_prefix(asm_id):
    """Toggle the ``GCA_``/``GCF_`` prefix of an assembly accession.

    ``GCA_...`` becomes ``GCF_...`` and vice versa. Non-string values and
    strings carrying neither prefix are returned as-is.
    """
    counterpart_of = {"GCA_": "GCF_", "GCF_": "GCA_"}
    if isinstance(asm_id, str):
        swapped = counterpart_of.get(asm_id[:4])
        if swapped is not None:
            return swapped + asm_id[4:]
    return asm_id
165
166
def is_refseq_nuccore(nuc_id):
    """Return True if ``nuc_id`` is a string starting with a RefSeq accession prefix.

    Covers the RefSeq nucleotide prefixes (AC_, NC_, NG_, NT_, NW_, NZ_, NM_,
    NR_, XM_, XR_) and also accepts the RefSeq protein prefixes (AP_, NP_, YP_,
    XP_, WP_). Non-string input returns False.
    """
    refseq_prefixes = (
        # Nucleotide accessions (genomic, contig/scaffold, transcript).
        "AC_", "NC_", "NG_", "NT_", "NW_", "NZ_", "NM_", "NR_", "XM_", "XR_",
        # Protein accessions.
        "AP_", "NP_", "YP_", "XP_", "WP_",
    )
    return isinstance(nuc_id, str) and nuc_id.startswith(refseq_prefixes)
171
172
def read_input_sheet(filename):
    """Load a tab-separated input sheet into a Polars DataFrame.

    Every column is read as a string. An ``og_index`` column (0-based row
    number, cast to string) is prepended, and each expected column that is
    absent from the file is appended filled with nulls.

    Args:
        filename: Path to the TSV input sheet.

    Returns:
        The input sheet as a DataFrame including all expected columns.
    """
    # ``dtype=`` is a pandas keyword that polars' read_csv does not accept;
    # infer_schema_length=0 is how polars reads every column as a string.
    df = pl.read_csv(filename, separator="\t", infer_schema_length=0)

    # ``with_row_count`` is deprecated in newer polars in favour of
    # ``with_row_index``; use whichever this polars version provides.
    if hasattr(df, "with_row_index"):
        df = df.with_row_index("og_index")
    else:
        df = df.with_row_count("og_index")
    df = df.with_columns(pl.col("og_index").cast(pl.Utf8))

    expected_columns = [
        "protein_id",
        "nucleotide_id",
        "uniprot_id",
        "gff_path",
        "faa_path",
        "fna_path",
        "gbf_path",
        "taxid",
        "assembly_id",
        "failed",
        "input_type",
        "premade",
    ]
    missing = [pl.lit(None).alias(col) for col in expected_columns if col not in df.columns]
    if missing:
        df = df.with_columns(missing)
    return df
195
196
def read_input_list(filename):
    """Parse a plain-text list of IDs (one per line) into a Polars DataFrame.

    Blank lines are skipped, but ``og_index`` records each entry's 0-based line
    position in the file, so indices may have gaps. Surrounding whitespace is
    stripped from each line before it is classified with ``categorize_id``:

    * ``"protein"``: sets ``protein_id``; ``input_type="protein"``.
    * ``"nucleotide"``: sets ``nucleotide_id`` and any ``protein_id`` returned
      by ``categorize_id``; ``input_type="nucleotide"``.
    * ``"uniprot"``: sets ``uniprot_id``; ``input_type="protein"``.
    * ``"unmatched"``: ``failed=True``, ``failed_reason="not valid ID"``.
    * any other category: if the line contains ``:`` and the part before it is
      a nucleotide accession, ``START-END`` coordinates are recorded as strings
      in ascending numeric order (``strand`` is ``"-"`` when the input was
      descending, else ``"+"``); malformed coordinates mark the row failed.
      Afterwards, if ``protein_id`` is still unset and the category has an
      ``id``, that id is used as ``protein_id``.

    Args:
        filename: Path to the text file.

    Returns:
        A DataFrame with one row per non-blank line.
    """
    data = []
    for index, raw_line in enumerate(Path(filename).read_text().splitlines()):
        id_ = raw_line.strip()
        if not id_:
            continue

        category = categorize_id(id_)
        record = {
            "og_index": index,
            "protein_id": None,
            "nucleotide_id": None,
            "uniprot_id": None,
            "failed": None,
            "failed_reason": None,
            "gff_path": None,
            "faa_path": None,
            "fna_path": None,
            "strand": None,
            "start": None,
            "end": None,
            "gbf_path": None,
            "taxid": None,
            "assembly_id": None,
            "input_type": None,
            "premade": None,
        }

        if category["type"] == "protein":
            record["protein_id"] = category["id"]
            record["input_type"] = "protein"
        elif category["type"] == "nucleotide":
            record["nucleotide_id"] = category["id"]
            record["protein_id"] = category.get("protein_id")
            record["input_type"] = "nucleotide"
        elif category["type"] == "uniprot":
            record["uniprot_id"] = category["id"]
            record["input_type"] = "protein"
        elif category["type"] == "unmatched":
            record["failed"] = True
            record["failed_reason"] = "not valid ID"
        else:
            if ":" in id_:
                parts = id_.split(":")
                category = categorize_id(parts[0])
                if category["type"] == "nucleotide":
                    coords = parts[1].split("-")
                    if "-" in parts[1] and coords[0].isdecimal() and coords[1].isdecimal():
                        start, end = coords[0], coords[1]
                        # Compare numerically: as strings, "900" > "1000" is True.
                        if int(start) > int(end):
                            start, end = end, start
                            record["strand"] = "-"
                        else:
                            record["strand"] = "+"
                        record["start"], record["end"] = start, end
                    else:
                        record["failed"] = True
                        record["failed_reason"] = "not valid ID"
                    record["nucleotide_id"] = category["id"]
                    record["input_type"] = "nucleotide"

            if record["protein_id"] is None and category.get("id"):
                record["protein_id"] = category["id"]
                record["input_type"] = record.get("input_type") or "protein"
        data.append(record)

    # Scan every record when inferring the schema so sparsely filled columns
    # (e.g. start/end) are typed from rows that actually carry values.
    return pl.DataFrame(data, infer_schema_length=None)
268
269
def uniprot2ncbi(df: pl.DataFrame) -> pl.DataFrame:
    """
    Map UniProt accessions to NCBI protein IDs using UniProtMapper.

    - Only rows with ``uniprot_id`` present, ``protein_id`` null/empty, and no
      local files (``gff_path``/``faa_path`` null) are selected; only those
      rows receive a mapped ``protein_id`` or a failure flag.
    - Accessions are mapped from ``UniProtKB_AC-ID`` to ``EMBL-GenBank-DDBJ_CDS``.
    - The mapping table is left-joined on ``uniprot_id``, so a UniProt entry
      with several hits yields one row per hit (rows sharing that
      ``uniprot_id`` are duplicated by the join).
    - Accessions the mapper reports as failed (or every requested accession,
      if the request raises) get ``failed=True`` and a ``failed_reason``.

    Returns ``df`` unchanged when a required column is missing or no row needs
    mapping.
    """
    required_cols = {"uniprot_id", "protein_id", "gff_path", "faa_path"}
    if not required_cols.issubset(df.columns):
        return df

    # Cast protein_id to Utf8 to handle Null-type columns safely
    protein_id_str = df["protein_id"].cast(pl.Utf8, strict=False)
    mask = (
        df["uniprot_id"].is_not_null()
        & (protein_id_str.is_null() | (protein_id_str == ""))
        & df["gff_path"].is_null()
        & df["faa_path"].is_null()
    )

    if mask.sum() == 0:
        return df

    from UniProtMapper import ProtMapper

    df_pd = df.to_pandas()
    needs_mapping = mask.to_pandas().to_numpy(dtype=bool)
    to_map = df_pd.loc[needs_mapping, "uniprot_id"].dropna().unique().tolist()
    if not to_map:
        return df

    # Temporary flag column so the row selection survives the merge below,
    # which rebuilds the index and may duplicate rows.
    needs_col = "_uniprot2ncbi_needs_mapping"
    df_pd[needs_col] = needs_mapping

    mapper = ProtMapper()
    try:
        mapped_df, failed_ids = mapper.get(
            ids=to_map, from_db="UniProtKB_AC-ID", to_db="EMBL-GenBank-DDBJ_CDS"
        )
    except Exception as e:
        # Best effort: a failed request marks every requested accession failed.
        warn(f"UniProt to NCBI mapping failed: {e}")
        failed_ids = to_map
        mapped_df = None

    if mapped_df is not None and not mapped_df.empty:
        df_pd = df_pd.merge(
            mapped_df[["From", "To"]], left_on="uniprot_id", right_on="From", how="left"
        )
        current = df_pd["protein_id"]
        # The selection treats "" like null, so fill both (fillna alone skips "").
        fill = df_pd[needs_col] & (current.isna() | (current == "")) & df_pd["To"].notna()
        df_pd["protein_id"] = current.where(~fill, df_pd["To"])
        df_pd = df_pd.drop(columns=["From", "To"])

    if failed_ids:
        failed_mask = df_pd[needs_col] & df_pd["uniprot_id"].isin(failed_ids)
        df_pd.loc[failed_mask, "failed"] = True
        df_pd.loc[failed_mask, "failed_reason"] = "No associated NCBI found for the UniProt entry."

    return pl.from_pandas(df_pd.drop(columns=[needs_col]))
def validate_input_file(ctx, param, value):
def validate_input_file(ctx, param, value):
    """Click callback that validates an input file before a run starts.

    Values that are not paths to existing files are returned unchanged, so
    literal IDs or FASTA strings can be handed straight to the runner.

    * ``input_path``: the file must have more than one non-blank line, no
      commas or tabs, and only letters, digits, whitespace, ``.`` and ``_``.
    * ``inputsheet``: the file must be tab-separated, include the columns
      ``nucleotide_id``, ``protein_id``, ``gff_path``, ``fna_path`` and
      ``faa_path``, and have at least one row with a non-blank
      ``nucleotide_id`` or ``protein_id``.

    For any other parameter name the file is only opened, not inspected.

    Args:
        ctx: Click context (unused).
        param: Click parameter being validated; only ``param.name`` is read.
        value: Option value (a path or a literal string), or None.

    Returns:
        None when ``value`` is None, otherwise ``value`` unchanged.

    Raises:
        click.BadParameter: If validation fails or the file cannot be read.
    """
    if value is None:
        return None

    if not Path(value).is_file():
        # Allow literal IDs/FASTA strings; runner will handle them.
        return value

    try:
        with open(value) as file:
            if param.name == "input_path":
                # Keep the real 1-based file line numbers so error messages point
                # at the right line even when the file contains blank lines.
                lines = [
                    (lineno, line.strip())
                    for lineno, line in enumerate(file, start=1)
                    if line.strip()
                ]
                if len(lines) <= 1:
                    raise click.BadParameter("File must contain multiple lines.")

                for _, line in lines:
                    if "," in line or "\t" in line:
                        raise click.BadParameter(
                            "File must be a single-column text file without delimiters like commas or tabs."
                        )

                special_char_pattern = re.compile(r"[^A-Za-z0-9\s._]+")
                for lineno, line in lines:
                    match = special_char_pattern.search(line)
                    if match:
                        raise click.BadParameter(
                            f"Invalid character '{match.group()}' found in line {lineno}: \"{line}\""
                        )

            elif param.name == "inputsheet":
                if "\t" not in file.readline():
                    raise click.BadParameter("The file is not in TSV format.")
                file.seek(0)
                reader = csv.DictReader(file, delimiter="\t")
                required_columns = [
                    "nucleotide_id",
                    "protein_id",
                    "gff_path",
                    "fna_path",
                    "faa_path",
                ]
                fieldnames = reader.fieldnames or []
                if not all(col in fieldnames for col in required_columns):
                    raise click.BadParameter(
                        "TSV file must contain columns: nucleotide_id, protein_id, gff_path, fna_path, faa_path"
                    )

                # DictReader fills cells missing from short rows with None, so
                # coerce to "" before stripping instead of crashing.
                if not any(
                    (row.get("nucleotide_id") or "").strip()
                    or (row.get("protein_id") or "").strip()
                    for row in reader
                ):
                    raise click.BadParameter(
                        "The TSV file must contain at least one valid row with required data."
                    )

    except click.BadParameter:
        # Our own validation errors already carry a precise message; do not
        # re-wrap them as "Error reading file: ...".
        raise
    except Exception as e:
        raise click.BadParameter(f"Error reading file: {e}") from e

    return value

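Validate an input file as either a single-column ID list (`input_path`) or a tab-separated input sheet (`inputsheet`); values that are not existing files are passed through unchanged.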

def validate_domains(ctx, param, value):
 82def validate_domains(ctx, param, value):
 83    """Validate domain database names against MetaCerberus availability."""
 84    if not value or (isinstance(value, str) and value.strip() == ""):
 85        return None
 86
 87    db_names = [d.strip().lower() for d in value.split(",") if d.strip()]
 88    if not db_names:
 89        raise click.BadParameter("Domain parameter must contain at least one database name.")
 90
 91    for db in db_names:
 92        if not re.match(r"^[a-zA-Z0-9_-]+$", db):
 93            raise click.BadParameter(
 94                f"Invalid database name '{db}'. Only letters, numbers, underscores, and hyphens are allowed."
 95            )
 96
 97    try:
 98        from hoodini.download.metacerberus import check_downloaded, get_db_groups, list_db_files
 99
100        files_list = list_db_files()
101        groups = get_db_groups(files_list)
102        status = check_downloaded(groups)
103
104        valid_dbs = []
105        invalid_dbs = []
106        missing_files_dbs = []
107
108        for db in db_names:
109            if db not in groups:
110                invalid_dbs.append(db)
111                continue
112
113            file_statuses = status.get(db, [])
114            hmm_present = any(
115                present for f, present in file_statuses if f["name"].endswith(".hmm.gz")
116            )
117            tsv_present = any(present for f, present in file_statuses if f["name"].endswith(".tsv"))
118
119            has_hmm_file = any(f["name"].endswith(".hmm.gz") for f, _ in file_statuses)
120            has_tsv_file = any(f["name"].endswith(".tsv") for f, _ in file_statuses)
121
122            if has_tsv_file and not tsv_present:
123                missing_files_dbs.append(f"{db} (missing TSV)")
124            elif has_hmm_file and not hmm_present:
125                missing_files_dbs.append(f"{db} (missing HMM)")
126            elif not has_hmm_file and not has_tsv_file:
127                invalid_dbs.append(db)
128            elif (has_hmm_file and hmm_present and has_tsv_file and tsv_present) or (
129                not has_hmm_file and has_tsv_file and tsv_present
130            ):
131                valid_dbs.append(db)
132            else:
133                missing_files_dbs.append(db)
134
135        if invalid_dbs:
136            raise click.BadParameter(
137                f"Unknown MetaCerberus databases: {', '.join(invalid_dbs)}. Run 'hoodini download metacerberus' to see available databases."
138            )
139
140        if missing_files_dbs:
141            raise click.BadParameter(
142                f"MetaCerberus databases not downloaded: {', '.join(missing_files_dbs)}. Run 'hoodini download metacerberus {','.join(db_names)}' to download them."
143            )
144
145        if not valid_dbs:
146            raise click.BadParameter("No valid MetaCerberus databases found.")
147
148        return valid_dbs
149
150    except ImportError as e:
151        warn(f"Could not validate MetaCerberus databases: {e}")
152        return value
153    except Exception as e:
154        warn(f"Could not validate MetaCerberus databases: {e}")
155        return value

Validate domain database names against MetaCerberus availability.

def switch_assembly_prefix(asm_id):
def switch_assembly_prefix(asm_id):
    """Toggle the ``GCA_``/``GCF_`` prefix of an assembly accession.

    ``GCA_...`` becomes ``GCF_...`` and vice versa. Non-string values and
    strings carrying neither prefix are returned as-is.
    """
    counterpart_of = {"GCA_": "GCF_", "GCF_": "GCA_"}
    if isinstance(asm_id, str):
        swapped = counterpart_of.get(asm_id[:4])
        if swapped is not None:
            return swapped + asm_id[4:]
    return asm_id
def is_refseq_nuccore(nuc_id):
def is_refseq_nuccore(nuc_id):
    """Return True if ``nuc_id`` is a string starting with a RefSeq accession prefix.

    Covers the RefSeq nucleotide prefixes (AC_, NC_, NG_, NT_, NW_, NZ_, NM_,
    NR_, XM_, XR_) and also accepts the RefSeq protein prefixes (AP_, NP_, YP_,
    XP_, WP_). Non-string input returns False.
    """
    refseq_prefixes = (
        # Nucleotide accessions (genomic, contig/scaffold, transcript).
        "AC_", "NC_", "NG_", "NT_", "NW_", "NZ_", "NM_", "NR_", "XM_", "XR_",
        # Protein accessions.
        "AP_", "NP_", "YP_", "XP_", "WP_",
    )
    return isinstance(nuc_id, str) and nuc_id.startswith(refseq_prefixes)

Return True if the nuccore accession is a RefSeq accession else False.

def read_input_sheet(filename):
def read_input_sheet(filename):
    """Load a tab-separated input sheet into a Polars DataFrame.

    Every column is read as a string. An ``og_index`` column (0-based row
    number, cast to string) is prepended, and each expected column that is
    absent from the file is appended filled with nulls.

    Args:
        filename: Path to the TSV input sheet.

    Returns:
        The input sheet as a DataFrame including all expected columns.
    """
    # ``dtype=`` is a pandas keyword that polars' read_csv does not accept;
    # infer_schema_length=0 is how polars reads every column as a string.
    df = pl.read_csv(filename, separator="\t", infer_schema_length=0)

    # ``with_row_count`` is deprecated in newer polars in favour of
    # ``with_row_index``; use whichever this polars version provides.
    if hasattr(df, "with_row_index"):
        df = df.with_row_index("og_index")
    else:
        df = df.with_row_count("og_index")
    df = df.with_columns(pl.col("og_index").cast(pl.Utf8))

    expected_columns = [
        "protein_id",
        "nucleotide_id",
        "uniprot_id",
        "gff_path",
        "faa_path",
        "fna_path",
        "gbf_path",
        "taxid",
        "assembly_id",
        "failed",
        "input_type",
        "premade",
    ]
    missing = [pl.lit(None).alias(col) for col in expected_columns if col not in df.columns]
    if missing:
        df = df.with_columns(missing)
    return df
def read_input_list(filename):
def read_input_list(filename):
    """Parse a plain-text list of IDs (one per line) into a Polars DataFrame.

    Blank lines are skipped, but ``og_index`` records each entry's 0-based line
    position in the file, so indices may have gaps. Surrounding whitespace is
    stripped from each line before it is classified with ``categorize_id``:

    * ``"protein"``: sets ``protein_id``; ``input_type="protein"``.
    * ``"nucleotide"``: sets ``nucleotide_id`` and any ``protein_id`` returned
      by ``categorize_id``; ``input_type="nucleotide"``.
    * ``"uniprot"``: sets ``uniprot_id``; ``input_type="protein"``.
    * ``"unmatched"``: ``failed=True``, ``failed_reason="not valid ID"``.
    * any other category: if the line contains ``:`` and the part before it is
      a nucleotide accession, ``START-END`` coordinates are recorded as strings
      in ascending numeric order (``strand`` is ``"-"`` when the input was
      descending, else ``"+"``); malformed coordinates mark the row failed.
      Afterwards, if ``protein_id`` is still unset and the category has an
      ``id``, that id is used as ``protein_id``.

    Args:
        filename: Path to the text file.

    Returns:
        A DataFrame with one row per non-blank line.
    """
    data = []
    for index, raw_line in enumerate(Path(filename).read_text().splitlines()):
        id_ = raw_line.strip()
        if not id_:
            continue

        category = categorize_id(id_)
        record = {
            "og_index": index,
            "protein_id": None,
            "nucleotide_id": None,
            "uniprot_id": None,
            "failed": None,
            "failed_reason": None,
            "gff_path": None,
            "faa_path": None,
            "fna_path": None,
            "strand": None,
            "start": None,
            "end": None,
            "gbf_path": None,
            "taxid": None,
            "assembly_id": None,
            "input_type": None,
            "premade": None,
        }

        if category["type"] == "protein":
            record["protein_id"] = category["id"]
            record["input_type"] = "protein"
        elif category["type"] == "nucleotide":
            record["nucleotide_id"] = category["id"]
            record["protein_id"] = category.get("protein_id")
            record["input_type"] = "nucleotide"
        elif category["type"] == "uniprot":
            record["uniprot_id"] = category["id"]
            record["input_type"] = "protein"
        elif category["type"] == "unmatched":
            record["failed"] = True
            record["failed_reason"] = "not valid ID"
        else:
            if ":" in id_:
                parts = id_.split(":")
                category = categorize_id(parts[0])
                if category["type"] == "nucleotide":
                    coords = parts[1].split("-")
                    if "-" in parts[1] and coords[0].isdecimal() and coords[1].isdecimal():
                        start, end = coords[0], coords[1]
                        # Compare numerically: as strings, "900" > "1000" is True.
                        if int(start) > int(end):
                            start, end = end, start
                            record["strand"] = "-"
                        else:
                            record["strand"] = "+"
                        record["start"], record["end"] = start, end
                    else:
                        record["failed"] = True
                        record["failed_reason"] = "not valid ID"
                    record["nucleotide_id"] = category["id"]
                    record["input_type"] = "nucleotide"

            if record["protein_id"] is None and category.get("id"):
                record["protein_id"] = category["id"]
                record["input_type"] = record.get("input_type") or "protein"
        data.append(record)

    # Scan every record when inferring the schema so sparsely filled columns
    # (e.g. start/end) are typed from rows that actually carry values.
    return pl.DataFrame(data, infer_schema_length=None)
def uniprot2ncbi(df: polars.dataframe.frame.DataFrame) -> polars.dataframe.frame.DataFrame:
def uniprot2ncbi(df: pl.DataFrame) -> pl.DataFrame:
    """
    Map UniProt accessions to NCBI protein IDs using UniProtMapper.

    - Only rows with ``uniprot_id`` present, ``protein_id`` null/empty, and no
      local files (``gff_path``/``faa_path`` null) are selected; only those
      rows receive a mapped ``protein_id`` or a failure flag.
    - Accessions are mapped from ``UniProtKB_AC-ID`` to ``EMBL-GenBank-DDBJ_CDS``.
    - The mapping table is left-joined on ``uniprot_id``, so a UniProt entry
      with several hits yields one row per hit (rows sharing that
      ``uniprot_id`` are duplicated by the join).
    - Accessions the mapper reports as failed (or every requested accession,
      if the request raises) get ``failed=True`` and a ``failed_reason``.

    Returns ``df`` unchanged when a required column is missing or no row needs
    mapping.
    """
    required_cols = {"uniprot_id", "protein_id", "gff_path", "faa_path"}
    if not required_cols.issubset(df.columns):
        return df

    # Cast protein_id to Utf8 to handle Null-type columns safely
    protein_id_str = df["protein_id"].cast(pl.Utf8, strict=False)
    mask = (
        df["uniprot_id"].is_not_null()
        & (protein_id_str.is_null() | (protein_id_str == ""))
        & df["gff_path"].is_null()
        & df["faa_path"].is_null()
    )

    if mask.sum() == 0:
        return df

    from UniProtMapper import ProtMapper

    df_pd = df.to_pandas()
    needs_mapping = mask.to_pandas().to_numpy(dtype=bool)
    to_map = df_pd.loc[needs_mapping, "uniprot_id"].dropna().unique().tolist()
    if not to_map:
        return df

    # Temporary flag column so the row selection survives the merge below,
    # which rebuilds the index and may duplicate rows.
    needs_col = "_uniprot2ncbi_needs_mapping"
    df_pd[needs_col] = needs_mapping

    mapper = ProtMapper()
    try:
        mapped_df, failed_ids = mapper.get(
            ids=to_map, from_db="UniProtKB_AC-ID", to_db="EMBL-GenBank-DDBJ_CDS"
        )
    except Exception as e:
        # Best effort: a failed request marks every requested accession failed.
        warn(f"UniProt to NCBI mapping failed: {e}")
        failed_ids = to_map
        mapped_df = None

    if mapped_df is not None and not mapped_df.empty:
        df_pd = df_pd.merge(
            mapped_df[["From", "To"]], left_on="uniprot_id", right_on="From", how="left"
        )
        current = df_pd["protein_id"]
        # The selection treats "" like null, so fill both (fillna alone skips "").
        fill = df_pd[needs_col] & (current.isna() | (current == "")) & df_pd["To"].notna()
        df_pd["protein_id"] = current.where(~fill, df_pd["To"])
        df_pd = df_pd.drop(columns=["From", "To"])

    if failed_ids:
        failed_mask = df_pd[needs_col] & df_pd["uniprot_id"].isin(failed_ids)
        df_pd.loc[failed_mask, "failed"] = True
        df_pd.loc[failed_mask, "failed_reason"] = "No associated NCBI found for the UniProt entry."

    return pl.from_pandas(df_pd.drop(columns=[needs_col]))

Map UniProt accessions to NCBI protein IDs using UniProtMapper.

  • Only attempts mapping for rows with uniprot_id present, protein_id null/empty, and no local files (gff_path/faa_path missing).
  • The mapping table is left-joined on uniprot_id, so a UniProt entry with multiple hits produces one row per hit.
  • On failure, sets failed to True and records a descriptive message in failed_reason.