2023-05-05 14:43:23 -08:00

195 lines
6.6 KiB
Python

from functools import wraps
from uuid import UUID
from pydantic import BaseModel, Field, NoneStr
from sqlalchemy import select
from sqlalchemy.orm import MANYTOMANY, MANYTOONE, ONETOMANY, Session
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.relationships import RelationshipProperty
from sqlalchemy.sql.base import ColumnCollection
from .._model_base import SqlAlchemyBase
from .helpers import safe_call
def _default_exclusion() -> set[str]:
return {"id"}
class AutoInitConfig(BaseModel):
    """
    Config class for `auto_init` decorator.

    `_get_config` creates an instance with these defaults and then overrides each field
    with the same-named attribute of a model's nested ``Config`` class, when one exists.
    """

    # Name of the attribute used to look up related rows. When left as None,
    # `get_lookup_attr` falls back to the first primary-key column of the related table.
    get_attr: NoneStr = None
    # Keyword names that `auto_init` and `handle_one_to_many_list` never assign.
    # Defaults to {"id"} (see `_default_exclusion`).
    exclude: set = Field(default_factory=_default_exclusion)
    # auto_create: bool = False
def _get_config(relation_cls: type[SqlAlchemyBase]) -> AutoInitConfig:
    """
    Builds the effective `AutoInitConfig` for the given class.

    Starts from the default config and copies over every attribute of the class's
    ``Config`` attribute whose name matches an `AutoInitConfig` field. When the class
    has no ``Config`` attribute the defaults are returned unchanged.
    """
    resolved = AutoInitConfig()
    known_fields = resolved.dict().keys()

    try:
        overrides = relation_cls.Config
    except AttributeError:
        # Nothing declared on the class, so there is nothing to overlay.
        return resolved

    for name in dir(overrides):
        if name not in known_fields:
            continue
        setattr(resolved, name, getattr(overrides, name))

    return resolved
def get_lookup_attr(relation_cls: type[SqlAlchemyBase]) -> str:
    """Returns the name of the attribute used to look up rows of the related class.

    Args:
        relation_cls: The SQLAlchemy class whose lookup attribute is wanted.

    Returns:
        The ``get_attr`` value from the class's config when it is not None, otherwise
        the key of the first primary-key column of the class's table. If anything goes
        wrong while resolving either of those, ``"id"`` is returned.
    """
    config = _get_config(relation_cls)

    try:
        configured = config.get_attr
        if configured is not None:
            return configured
        return relation_cls.__table__.primary_key.columns.keys()[0]
    except Exception:
        # Best-effort fallback: any failure above resolves to the conventional "id".
        return "id"
def handle_many_to_many(session, get_attr, relation_cls, all_elements: list[dict]):
    """
    Resolves the elements of a many-to-many relationship.

    Delegates straight to `handle_one_to_many_list`, forwarding all four arguments
    unchanged, since both relationship kinds are resolved the same way.
    """
    return handle_one_to_many_list(
        session=session,
        get_attr=get_attr,
        relation_cls=relation_cls,
        all_elements=all_elements,
    )
def handle_one_to_many_list(
    session: Session, get_attr, relation_cls: type[SqlAlchemyBase], all_elements: list[dict] | list[str]
):
    """
    Resolves a list of related elements into instances of `relation_cls`.

    Each element is either a dict of attributes or a bare identifier. The element's
    `get_attr` value (or the bare identifier itself) is used to look up an existing row:

    - dict with a matching row: the row is updated in place with every dict item whose key
      is not in the class's configured `exclude` set.
    - dict without a matching row: a new instance is created through `safe_call`.
    - bare identifier with a matching row: the row is used as-is.
    - bare identifier without a matching row: the element is skipped, so the result
      never contains None.

    Args:
        session: Session used to run one lookup query per element.
        get_attr: Name of the attribute the lookup filters on.
        relation_cls: The related SQLAlchemy class.
        all_elements: The dicts or bare identifiers to resolve.

    Returns:
        A list with the newly created instances first, followed by the existing ones.
    """
    elems_to_create: list[dict] = []
    updated_elems: list[SqlAlchemyBase] = []

    # Loop-invariant: resolve the exclusion set once instead of per element.
    exclude = _get_config(relation_cls).exclude

    for elem in all_elements:
        is_mapping = isinstance(elem, dict)
        elem_id = elem.get(get_attr, None) if is_mapping else elem
        stmt = select(relation_cls).filter_by(**{get_attr: elem_id})
        existing_elem = session.execute(stmt).scalars().one_or_none()

        if existing_elem is None:
            if is_mapping:
                elems_to_create.append(elem)
            # A bare identifier that matches no row carries no data to create one from;
            # skip it instead of placing None into the relationship collection.
            continue

        if is_mapping:
            for key, value in elem.items():
                if key not in exclude:
                    setattr(existing_elem, key, value)

        updated_elems.append(existing_elem)

    new_elems = [safe_call(relation_cls, elem, session=session) for elem in elems_to_create]
    return new_elems + updated_elems
def auto_init():  # sourcery no-metrics
    """Builds a decorator for a model's `__init__` that assigns columns and
    relationships from the keyword arguments before the original initializer runs.

    Takes no arguments. Per-class behaviour comes from the config resolved by
    `_get_config` for the instance's class.
    """

    def decorator(init):
        @wraps(init)
        def wrapper(self: SqlAlchemyBase, *args, **kwargs):  # sourcery no-metrics
            """
            Initializer wrapper that also initializes nested children.

            A keyword is applied only when it is not excluded by the class config and
            names an attribute of the instance's class. Column keywords are assigned
            directly; relationship keywords are resolved to related instances first.
            A ``session`` keyword is mandatory.

            Code inspired from GitHub.
            Ref: https://github.com/tiangolo/fastapi/issues/2194
            """
            model_cls = self.__class__
            skipped = _get_config(model_cls).exclude

            mapper: Mapper = self.__mapper__
            columns: ColumnCollection = mapper.columns
            relations = mapper.relationships

            session: Session = kwargs.get("session", None)
            if session is None:
                raise ValueError("Session is required to initialize the model with `auto_init`")

            for key, value in kwargs.items():
                # Excluded names and names unknown to the class are silently ignored.
                if key in skipped or not hasattr(model_cls, key):
                    continue

                if key in columns:
                    setattr(self, key, value)
                    continue

                if key not in relations:
                    continue

                rel: RelationshipProperty = relations[key]
                # One of ONETOMANY, MANYTOONE or MANYTOMANY.
                direction = rel.direction
                # Class on the other side of the relationship.
                target_cls: type[SqlAlchemyBase] = rel.mapper.entity
                # Whether the relationship is represented as a list.
                as_list: bool = rel.uselist
                lookup = get_lookup_attr(target_cls)

                if direction == ONETOMANY:
                    if as_list:
                        related = handle_one_to_many_list(session, lookup, target_cls, value)
                    else:
                        related = safe_call(target_cls, value, session=session)
                    setattr(self, key, related)

                elif direction == MANYTOONE and not as_list:
                    if isinstance(value, dict):
                        # Reduce the dict to its lookup value; a dict without one is an error.
                        # NOTE(review): the message says 'id' even when the lookup attribute differs.
                        value = value.get(lookup)
                        if value is None:
                            raise ValueError(f"Expected 'id' to be provided for {key}")

                    # Any other kind of value is assumed not to be a valid id and is left unassigned.
                    if isinstance(value, str | int | UUID):
                        query = select(target_cls).filter_by(**{lookup: value})
                        found = session.execute(query).scalars().one_or_none()
                        setattr(self, key, found)

                elif direction == MANYTOMANY:
                    related = handle_many_to_many(session, lookup, target_cls, value)
                    setattr(self, key, related)

            return init(self, *args, **kwargs)

        return wrapper

    return decorator