diff --git a/backend/adapters/orm.py b/backend/adapters/orm.py index 74ef842..3670f2c 100644 --- a/backend/adapters/orm.py +++ b/backend/adapters/orm.py @@ -1,9 +1,12 @@ -from sqlalchemy import Column, MetaData, String, Table -from sqlalchemy.orm import mapper +from sqlalchemy import Column, ForeignKey, MetaData, String, Table +from sqlalchemy.orm import backref, registry, relationship +from backend.model.assessment import Assessment +from backend.model.student import Student from backend.model.tribe import Tribe metadata = MetaData() +mapper_registry = registry() tribes_table = Table( "tribes", @@ -12,6 +15,35 @@ tribes_table = Table( Column("level", String(255)), ) +assessments_table = Table( + "assessments", + metadata, + Column("id", String(255), primary_key=True), + Column("name", String(255)), + Column("tribe_name", String(255), ForeignKey("tribes.name")), +) + +students_table = Table( + "students", + metadata, + Column("id", String(255), primary_key=True), + Column("name", String(255)), + Column("tribe_name", String(255), ForeignKey("tribes.name")), +) + def start_mappers(): - tribes_mapper = mapper(Tribe, tribes_table) + tribes_mapper = mapper_registry.map_imperatively( + Tribe, + tribes_table, + properties={ + "students": relationship( + Student, backref="tribes", order_by=students_table.c.id + ), + "assessments": relationship( + Assessment, backref="tribes", order_by=assessments_table.c.id + ), + }, + ) + students_mapper = mapper_registry.map_imperatively(Student, students_table) + assessments_mapper = mapper_registry.map_imperatively(Assessment, assessments_table) diff --git a/backend/model/student.py b/backend/model/student.py index e227c69..4404eca 100644 --- a/backend/model/student.py +++ b/backend/model/student.py @@ -20,3 +20,6 @@ class Student: if isinstance(other, Student): return self.id == other.id return False + + def __hash__(self) -> int: + return hash(self.id) diff --git a/backend/model/tribe.py b/backend/model/tribe.py index bb1b713..e6ff677 100644 --- a/backend/model/tribe.py +++ b/backend/model/tribe.py @@ -26,8 +26,5 @@ class Tribe: return self.name == other.name return False - def __repr__(self) -> str: - return f"" - def to_dict(self) -> dict: return {"name": self.name, "level": self.level} diff --git a/tests/integration/test_orm.py b/tests/integration/test_orm.py index a137e8c..b9cbc7e 100644 --- a/tests/integration/test_orm.py +++ b/tests/integration/test_orm.py @@ -1,4 +1,5 @@ from backend.adapters.orm import metadata, start_mappers +from backend.model.student import Student from backend.model.tribe import Tribe @@ -40,3 +41,25 @@ def test_tribe_mapper_can_save_and_load_tribe(session): session.commit() assert session.query(Tribe).all() == [tribe] + + +def test_students_mapper_can_load_student(session): + session.execute("INSERT INTO tribes (name, level) VALUES " "('tribe1', '2nd')") + + session.execute( + "INSERT INTO students (id, name, tribe_name) VALUES " + "('1', 'student1', 'tribe1')," + "('2', 'student2', 'tribe1')" + ) + + tribe = session.query(Tribe).one() + expected = [ + (Student("1", "student1", tribe)), + (Student("2", "student2", tribe)), + ] + + with session.no_autoflush: + students = session.query(Student).all() + + assert set(tribe.students) == set(expected) + assert students == expected