SQLModel 的知识点特别零散,所以用 AI 把常用操作封装了一下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from sqlmodel import SQLModel, Session, select, Field, create_engine
from typing import TypeVar, Type, List, Optional, Any
from sqlalchemy.exc import SQLAlchemyError

T = TypeVar("T", bound=SQLModel)

class BaseRepository:
    """Generic CRUD helper bound to a single SQLModel ``Session``.

    Write methods (``add``, ``update``, ``delete``) commit immediately by
    default.  Pass ``commit=False`` to only flush, so several writes can be
    grouped into one transaction that the caller finishes with ``commit()``
    or abandons with ``rollback()``.
    """

    def __init__(self, session: Session):
        """Store the session every repository operation will run on."""
        self.session = session

    def _persist(self, commit: bool) -> None:
        """Commit or flush pending changes, rolling back if that fails."""
        try:
            if commit:
                self.session.commit()
            else:
                # Flushing sends the pending SQL (so generated values such as
                # primary keys exist) while leaving the transaction open.
                self.session.flush()
        except SQLAlchemyError:
            # A failed flush/commit leaves the session unusable until it is
            # rolled back, so restore it before propagating the error.
            self.session.rollback()
            raise

    def _save(self, model: T, commit: bool) -> T:
        """Attach *model* to the session, persist it and reload its state."""
        self.session.add(model)
        self._persist(commit)
        self.session.refresh(model)
        return model

    def add(self, model: T, *, commit: bool = True) -> T:
        """Insert *model* and return it refreshed from the database.

        Args:
            model: Instance to persist.
            commit: Commit right away (default) or only flush so the caller
                controls the transaction boundary.

        Raises:
            SQLAlchemyError: If the write fails; the session is rolled back.
        """
        return self._save(model, commit)

    def get_by_id(self, model_type: Type[T], id: Any) -> Optional[T]:
        """Return the *model_type* row with primary key *id*, or ``None``."""
        return self.session.get(model_type, id)

    def get_all(self, model_type: Type[T]) -> List[T]:
        """Return every row of *model_type* as a list."""
        return list(self.session.exec(select(model_type)).all())

    def update(self, model: T, *, commit: bool = True) -> T:
        """Persist changes made to *model* and return it refreshed.

        Args:
            model: Instance carrying the modified attribute values.
            commit: Commit right away (default) or only flush.

        Raises:
            SQLAlchemyError: If the write fails; the session is rolled back.
        """
        return self._save(model, commit)

    def delete(self, model: T, *, commit: bool = True) -> None:
        """Delete *model*.

        Args:
            model: Instance to remove.
            commit: Commit right away (default) or only flush.

        Raises:
            SQLAlchemyError: If the write fails; the session is rolled back.
        """
        self.session.delete(model)
        self._persist(commit)

    def filter(self, model_type: Type[T], **kwargs) -> List[T]:
        """Return the *model_type* rows whose columns equal the given kwargs."""
        query = select(model_type).filter_by(**kwargs)
        return list(self.session.exec(query).all())

    def begin_transaction(self):
        """Start a transaction unless the session already has one open.

        The session begins a transaction on its own as soon as it is used,
        and calling ``begin()`` a second time raises, so this is a no-op when
        a transaction is already in progress.
        """
        if not self.session.in_transaction():
            self.session.begin()

    def commit(self):
        """Commit the session's current transaction."""
        self.session.commit()

    def rollback(self):
        """Roll back the session's current transaction."""
        self.session.rollback()



# User model: a table-backed SQLModel class.
class User(SQLModel, table=True):
    # Primary key; left as None on new instances that have not been saved.
    id: Optional[int] = Field(primary_key=True, default=None)
    # Required text columns.
    name: str
    email: str

class UserRepository(BaseRepository):
    """Repository adding lookups that are specific to ``User``."""

    def get_by_email(self, email: str) -> Optional[User]:
        """Return the first user whose e-mail equals *email*, or ``None``."""
        statement = select(User).where(User.email == email)
        matches = self.session.exec(statement)
        return matches.first()

    # More User-specific methods can be added here.


# Database session management
from contextlib import contextmanager

# Database URL selecting the SQLite file "database.sqlite".
DATABASE_URL = "sqlite:///database.sqlite"
# Module-level engine shared by create_db_and_tables() and get_session();
# echo=True turns on SQLAlchemy's logging of the SQL it emits.
engine = create_engine(DATABASE_URL, echo=True)

def create_db_and_tables():
    """Create the tables of every registered SQLModel table class."""
    metadata = SQLModel.metadata
    metadata.create_all(engine)

@contextmanager
def get_session():
    """Yield a session on the module engine, committing when the block ends.

    A ``SQLAlchemyError`` raised inside the block triggers a rollback and is
    re-raised.  The session is closed on exit in every case.
    """
    # The Session context manager closes the session when the block is left.
    with Session(engine) as db:
        try:
            yield db
            db.commit()
        except SQLAlchemyError:
            db.rollback()
            raise

# Usage example
if __name__ == "__main__":
    create_db_and_tables()
    with get_session() as session:
        user_repo = UserRepository(session)

        # Add a user.
        # NOTE(review): "[email protected]" looks like an e-mail-obfuscation
        # placeholder — substitute a real address when adapting this example.
        new_user = User(name="John Doe", email="[email protected]")
        added_user = user_repo.add(new_user)
        print(f"Added user: {added_user.id}")

        # Fetch the user that was just added.  Looking it up by its generated
        # id (instead of a hard-coded one) keeps the example working on a
        # fresh database, where a fixed id would not exist.
        user = user_repo.get_by_id(User, added_user.id)
        if user is None:
            raise LookupError(f"User {added_user.id} not found")
        print(f"Retrieved user: {user.name}")

        # Update the user.
        user.name = "Jane Doe"
        updated_user = user_repo.update(user)
        print(f"Updated user: {updated_user.name}")

        # Fetch all users.
        all_users = user_repo.get_all(User)
        print(f"Total users: {len(all_users)}")

        # Filter users by column values.
        filtered_users = user_repo.filter(User, name="Jane Doe")
        print(f"Filtered users: {len(filtered_users)}")

        # Delete the user.
        user_repo.delete(user)
        print("User deleted")

        # Transactional operation: both rows are saved together or not at all.
        repo = UserRepository(session)
        try:
            repo.begin_transaction()
            # repo.add() commits on every call, which would end the
            # transaction after the first row; stage both rows on the session
            # and commit once so the pair is atomic.
            user1 = User(name="User1", email="[email protected]")
            user2 = User(name="User2", email="[email protected]")
            session.add_all([user1, user2])
            repo.commit()
        except Exception:
            repo.rollback()
            raise