Skip to content

Commit 4885b99

Browse files
committed
add model_dump options to sqlmodel_update
1 parent 9c01aa6 commit 4885b99

2 files changed

Lines changed: 20 additions & 16 deletions

File tree

sqlmodel/main.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -984,25 +984,17 @@ def sqlmodel_update(
984984
obj: builtins.dict[str, Any] | BaseModel,
985985
*,
986986
update: builtins.dict[str, Any] | None = None,
987+
**model_dump_kwargs,
987988
) -> _TSQLModel:
988-
use_update = (update or {}).copy()
989-
if isinstance(obj, dict):
990-
for key, value in {**obj, **use_update}.items():
991-
if key in get_model_fields(self):
992-
setattr(self, key, value)
993-
elif isinstance(obj, BaseModel):
994-
for key in get_model_fields(obj):
995-
if key in use_update:
996-
value = use_update.pop(key)
997-
else:
998-
value = getattr(obj, key)
999-
setattr(self, key, value)
1000-
for remaining_key, value in use_update.items():
1001-
if remaining_key in get_model_fields(self):
1002-
setattr(self, remaining_key, value)
1003-
else:
989+
if not (isinstance(obj, dict) or isinstance(obj, BaseModel)):
1004990
raise ValueError(
1005991
"Can't use sqlmodel_update() with something that "
1006992
f"is not a dict or SQLModel or Pydantic model: {obj}"
1007993
)
994+
if isinstance(obj, BaseModel):
995+
obj = obj.model_dump(**model_dump_kwargs)
996+
use_update = (update or {}).copy()
997+
for key, value in {**obj, **use_update}.items():
998+
if key in get_model_fields(self):
999+
setattr(self, key, value)
10081000
return self

tests/test_update.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1+
from pytest import raises
12
from sqlmodel import Field, SQLModel
23

34

45
def test_sqlmodel_update():
56
class Organization(SQLModel, table=True):
67
id: int = Field(default=None, primary_key=True)
78
name: str
9+
city: str
810
headquarters: str
911

1012
class OrganizationUpdate(SQLModel):
1113
name: str
14+
city: str | None = None
1215

1316
org = Organization(name="Example Org", city="New York", headquarters="NYC HQ")
1417
org_in = OrganizationUpdate(name="Updated org")
@@ -17,4 +20,13 @@ class OrganizationUpdate(SQLModel):
1720
update={
1821
"headquarters": "-", # This field is in Organization, but not in OrganizationUpdate
1922
},
23+
exclude_unset=True
2024
)
25+
# fields that should stay the same
26+
assert org.city == "New York"
27+
#fields that should be updated
28+
assert org.name == "Updated org"
29+
assert org.headquarters == "-"
30+
# test raise value error when passing in updates other than dict or BaseModel
31+
with raises(ValueError):
32+
org.sqlmodel_update(["Boston"])

0 commit comments

Comments
 (0)