Update write.py

This commit is contained in:
Hellow2 2023-02-17 12:18:47 +01:00
parent 4e7fc5b9b3
commit 6531bf7022

View File

@ -1,9 +1,11 @@
from typing import Union, Set, Optional, Dict from typing import Union, Set, Optional, Dict, DefaultDict
from collections import defaultdict
import traceback import traceback
from peewee import ( from peewee import (
SqliteDatabase, SqliteDatabase,
MySQLDatabase, MySQLDatabase,
PostgresqlDatabase, PostgresqlDatabase,
Model
) )
from . import objects from . import objects
@ -41,6 +43,8 @@ class Session:
self.added_album_ids: Dict[str] = dict() self.added_album_ids: Dict[str] = dict()
self.added_artist_ids: Dict[str] = dict() self.added_artist_ids: Dict[str] = dict()
self.db_objects: DefaultDict[data_models.BaseModel, list] = defaultdict(list)
def __enter__(self, database: Database): def __enter__(self, database: Database):
""" """
Enter the context of the database session Enter the context of the database session
@ -70,7 +74,7 @@ class Session:
traceback.print_tb(exc_tb) traceback.print_tb(exc_tb)
print(f"Exception of type {exc_type} occurred with message: {exc_val}") print(f"Exception of type {exc_type} occurred with message: {exc_val}")
self.commit() self.commit(reset=False)
return exc_val is None return exc_val is None
def add_source(self, source: objects.Source, connected_to: data_models.Source.ContentTypes) -> data_models.Source: def add_source(self, source: objects.Source, connected_to: data_models.Source.ContentTypes) -> data_models.Source:
@ -81,6 +85,8 @@ class Session:
content_object=connected_to content_object=connected_to
).use(self.database) ).use(self.database)
self.db_objects[data_models.Source].append(db_source)
return db_source return db_source
def add_lyrics(self, lyrics: objects.Lyrics, song: data_models.Song) -> data_models.Lyrics: def add_lyrics(self, lyrics: objects.Lyrics, song: data_models.Song) -> data_models.Lyrics:
@ -91,6 +97,8 @@ class Session:
song=song song=song
).use(self.database) ).use(self.database)
self.db_objects[data_models.Lyrics].append(db_lyrics)
for source in lyrics.source_list: for source in lyrics.source_list:
self.add_source(source=source, connected_to=db_lyrics) self.add_source(source=source, connected_to=db_lyrics)
@ -104,6 +112,8 @@ class Session:
song=song song=song
).use(self.database) ).use(self.database)
self.db_objects[data_models.Target].append(db_target)
return db_target return db_target
def add_song(self, song: objects.Song) -> Optional[data_models.Song]: def add_song(self, song: objects.Song) -> Optional[data_models.Song]:
@ -127,6 +137,7 @@ class Session:
genre=song.genre genre=song.genre
).use(self.database) ).use(self.database)
self.db_objects[data_models.Song].append(db_song)
self.added_song_ids[song.id] = db_song self.added_song_ids[song.id] = db_song
for source in song.source_list: for source in song.source_list:
@ -171,6 +182,7 @@ class Session:
db_album = data_models.Album().use(self.database) db_album = data_models.Album().use(self.database)
self.db_objects[data_models.Album].append(db_album)
self.added_album_ids.add(album.id) self.added_album_ids.add(album.id)
return db_album return db_album
@ -189,12 +201,23 @@ class Session:
db_artist = data_models.Artist() db_artist = data_models.Artist()
self.db_objects[data_models.Artist].append(db_artist)
self.added_artist_ids[artist.id] = db_artist self.added_artist_ids[artist.id] = db_artist
return db_artist return db_artist
def commit(self): def commit(self, reset: bool = True):
""" """
Commit changes to the database Commit changes to the database
""" """
pass
for model, model_instance_list in self.db_objects.items():
model.Use(self.database).insert_many(model_instance_list)
if reset:
self.__init__(self.database)
if __name__ == "__main__":
with Session(SqliteDatabase(":memory:")) as session:
session.add_song(objects.Song(title="Hs"))