Source code for vyperdatum.db

import os
import sqlite3
import logging
from typing import Optional, Tuple, Union
import pyproj as pp
from pathlib import Path
import pandas as pd
from vyperdatum.enums import PROJDB as enuPDB


logger = logging.getLogger("root_logger")


[docs] class DB: def __init__(self, db_dir: Optional[str] = pp.datadir.get_data_dir() ) -> None: """ Parameters ---------- db_dir: str Path to the directory where the proj database file is located. The default values is pyproj data directory (`pyproj.datadir`). Returns ------- NoneType """ self.db_dir = db_dir self.db_name = enuPDB.FILE_NAME.value self.db_file_path = Path(self.db_dir).joinpath(self.db_name) if db_dir != pp.datadir.get_data_dir(): self.update_db_path() return def __str__(self): return (f"db_dir: {self.db_dir}\ndb_file_path: {self.db_file_path}" "\nproj.datadir: {pp.datadir.get_data_dir()}" ) def __repr__(self): return f"DB(data_dir = r'{self.db_dir}')" @property def db_file_path(self) -> str: return self._db_file_path @db_file_path.setter def db_file_path(self, db_path: str): if not db_path.is_file(): raise FileNotFoundError(f"Proj Database file not found at: {db_path}") self.db_dir = os.path.dirname(db_path) self.db_name = os.path.basename(db_path) self._db_file_path = db_path
[docs] def update_db_path(self) -> bool: """ Prepend `self.db_dir` to `pyproj.datadir` which guides the pyproj to first look for the database at `self.db_dir` address. Raises ------- ValueError: If `.db_dir` is not set. FileNotFoundError: If the database file is not found. Returns ------- bool: `True` if the `data_dir` is set successfully, otherwise `False`. """ try: success = False if not self.db_dir: raise ValueError("Attribute `.db_dir` not specified.") pp.datadir.set_data_dir(self.db_dir + ";" + pp.datadir.get_data_dir()) if pp.datadir.get_data_dir().split(";")[0] != self.db_dir: raise SystemExit("Unable to set the path to the custom PROJ database.") else: success = True finally: return success
def _connect(self) -> Tuple[Optional[object], Optional[object]]: """ Connect to the proj database. Returns ----------- obj: database connection object obj: database cursor object """ try: con, cur = None, None con = sqlite3.connect(self.db_file_path) cur = con.cursor() except Exception: logger.exception(f"Error while connecting to database at {self.db_file_path}.") return con, cur
[docs] def query(self, sql: str, dataframe: bool = False ) -> Union[Optional[list], Optional[pd.DataFrame]]: """ Execute a sql query and return the response. This method is intended to run a read (scan) query. Avoid using this method for DML/DDL type operations. Parameters ---------- sql: str SQL query (intended to be a scan query) to be executed. dataframe: bool, default=False If True, converts the result into a pandas dataframe. Returns -------- list or pd.DataFrame """ try: con, cur, res = None, None, None con, cur = self._connect() cur.execute(sql) res = cur.fetchall() # logger.info(res) if dataframe: res = pd.DataFrame.from_records(res, columns=[column[0] for column in cur.description] ) except Exception: logger.exception("Error in db.query.") finally: if cur: cur.close() if con: con.close() return res
[docs] def crs_by_keyword(self, keywords: list[str], dataframe: bool = False ) -> Union[Optional[list], Optional[pd.DataFrame]]: """ Return a list (or dataframe) of CRS that their name or description contain the passed keywords. The search is not case-sensitive. Parameters ---------- sql: str SQL query (intended to be a scan query) to be executed. keywords: list[str] A list of string keywords used to query the database. dataframe: bool, default=False If True, converts the result into a pandas dataframe. Raises ------- TypeError: If `keywords` is not a list of strings. ValueError: If no keywords is passed. Returns -------- list or pd.DataFrame """ if not isinstance(keywords, list): raise TypeError("keywords must be a list") if not all(map(lambda k: isinstance(k, str), keywords)): raise TypeError("keywords must be a list of strings") if len(keywords) < 1: raise ValueError("at least one keyword most be passed") where = "" filters = [f"(name like '%{k}%' OR description like '%{k}%')" for k in keywords] if len(filters) > 0: where = "WHERE " + " AND ".join(filters) return self.query(sql=f"SELECT * FROM {enuPDB.VIEW_CRS.value} {where}", dataframe=dataframe)