diff --git a/src/music_kraken/database/write.py b/src/music_kraken/database/write.py index 3b38c48..5bddeff 100644 --- a/src/music_kraken/database/write.py +++ b/src/music_kraken/database/write.py @@ -1,9 +1,11 @@ -from typing import Union, Set, Optional, Dict +from typing import Union, Set, Optional, Dict, DefaultDict +from collections import defaultdict import traceback from peewee import ( SqliteDatabase, MySQLDatabase, PostgresqlDatabase, + Model ) from . import objects @@ -41,6 +43,8 @@ class Session: self.added_album_ids: Dict[str] = dict() self.added_artist_ids: Dict[str] = dict() + self.db_objects: DefaultDict[data_models.BaseModel, list] = defaultdict(list) + def __enter__(self, database: Database): """ Enter the context of the database session @@ -70,7 +74,7 @@ class Session: traceback.print_tb(exc_tb) print(f"Exception of type {exc_type} occurred with message: {exc_val}") - self.commit() + self.commit(reset=False) return exc_val is None def add_source(self, source: objects.Source, connected_to: data_models.Source.ContentTypes) -> data_models.Source: @@ -81,6 +85,8 @@ class Session: content_object=connected_to ).use(self.database) + self.db_objects[data_models.Source].append(db_source) + return db_source def add_lyrics(self, lyrics: objects.Lyrics, song: data_models.Song) -> data_models.Lyrics: @@ -91,6 +97,8 @@ class Session: song=song ).use(self.database) + self.db_objects[data_models.Lyrics].append(db_lyrics) + for source in lyrics.source_list: self.add_source(source=source, connected_to=db_lyrics) @@ -104,6 +112,8 @@ class Session: song=song ).use(self.database) + self.db_objects[data_models.Target].append(db_target) + return db_target def add_song(self, song: objects.Song) -> Optional[data_models.Song]: @@ -127,6 +137,7 @@ class Session: genre=song.genre ).use(self.database) + self.db_objects[data_models.Song].append(db_song) self.added_song_ids[song.id] = db_song for source in song.source_list: @@ -171,6 +182,7 @@ class Session: db_album = data_models.Album().use(self.database) + self.db_objects[data_models.Album].append(db_album) self.added_album_ids.add(album.id) return db_album @@ -189,12 +201,23 @@ class Session: db_artist = data_models.Artist() + self.db_objects[data_models.Artist].append(db_artist) self.added_artist_ids[artist.id] = db_artist return db_artist - def commit(self): + def commit(self, reset: bool = True): """ Commit changes to the database """ - pass + + for model, model_instance_list in self.db_objects.items(): + model.Use(self.database).insert_many(model_instance_list) + + if reset: + self.__init__(self.database) + + +if __name__ == "__main__": + with Session(SqliteDatabase(":memory:")) as session: + session.add_song(objects.Song(title="Hs")) \ No newline at end of file