diff --git a/mealie/db/db_base.py b/mealie/db/db_base.py index f9e9a03a9b86..fe385f5a2d55 100644 --- a/mealie/db/db_base.py +++ b/mealie/db/db_base.py @@ -2,6 +2,7 @@ from typing import List from mealie.db.models.model_base import SqlAlchemyBase from pydantic import BaseModel +from sqlalchemy import func from sqlalchemy.orm import load_only from sqlalchemy.orm.session import Session @@ -64,7 +65,9 @@ class BaseDocument: return session.query(self.sql_model).filter_by(**{match_key: match_value}).one() - def get(self, session: Session, match_value: str, match_key: str = None, limit=1) -> BaseModel or List[BaseModel]: + def get( + self, session: Session, match_value: str, match_key: str = None, limit=1, any_case=False + ) -> BaseModel or List[BaseModel]: """Retrieves an entry from the database by matching a key/value pair. If no key is provided the class objects primary key will be used to match against. @@ -80,7 +83,13 @@ class BaseDocument: if match_key is None: match_key = self.primary_key - result = session.query(self.sql_model).filter_by(**{match_key: match_value}).limit(limit).all() + if any_case: + search_attr = getattr(self.sql_model, match_key) + result = ( + session.query(self.sql_model).filter(func.lower(search_attr) == match_value.lower()).limit(limit).all() + ) + else: + result = session.query(self.sql_model).filter_by(**{match_key: match_value}).limit(limit).all() if limit == 1: try: diff --git a/mealie/schema/user.py b/mealie/schema/user.py index 197706e518f1..78d4f831ecba 100644 --- a/mealie/schema/user.py +++ b/mealie/schema/user.py @@ -6,6 +6,7 @@ from mealie.db.models.group import Group from mealie.db.models.users import User from mealie.schema.category import CategoryBase from mealie.schema.meal import MealPlanInDB +from pydantic.types import constr from pydantic.utils import GetterDict @@ -23,7 +24,7 @@ class GroupBase(CamelModel): class UserBase(CamelModel): full_name: Optional[str] = None - email: str + email: constr(to_lower=True, strip_whitespace=True) admin: bool group: Optional[str] @@ -31,7 +32,7 @@ class UserBase(CamelModel): orm_mode = True @classmethod - def getter_dict(_cls, name_orm: User): + def getter_dict(cls, name_orm: User): return { **GetterDict(name_orm), "group": name_orm.group.name,