fixed collection typing

This commit is contained in:
Hellow 2023-08-28 20:59:19 +02:00
parent 31a7740760
commit c18ac58bd2
2 changed files with 28 additions and 25 deletions

View File

@ -1,10 +1,13 @@
from typing import List, Iterable, Dict from typing import List, Iterable, Dict, TypeVar, Generic
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from .parents import DatabaseObject from .parents import DatabaseObject
T = TypeVar('T', bound=DatabaseObject)
@dataclass @dataclass
class AppendResult: class AppendResult:
was_in_collection: bool was_in_collection: bool
@ -12,21 +15,21 @@ class AppendResult:
was_the_same: bool was_the_same: bool
class Collection: class Collection(Generic[T]):
""" """
This a class for the iterables This a class for the iterables
like tracklist or discography like tracklist or discography
""" """
_data: List[DatabaseObject] _data: List[T]
_by_url: dict _by_url: dict
_by_attribute: dict _by_attribute: dict
def __init__(self, data: List[DatabaseObject] = None, element_type=None, *args, **kwargs) -> None: def __init__(self, data: List[T] = None, element_type=None, *args, **kwargs) -> None:
# Attribute needs to point to # Attribute needs to point to
self.element_type = element_type self.element_type = element_type
self._data: List[DatabaseObject] = list() self._data: List[T] = list()
""" """
example of attribute_to_object_map example of attribute_to_object_map
@ -40,7 +43,7 @@ class Collection:
} }
``` ```
""" """
self._attribute_to_object_map: Dict[str, Dict[object, DatabaseObject]] = defaultdict(dict) self._attribute_to_object_map: Dict[str, Dict[object, T]] = defaultdict(dict)
self._used_ids: set = set() self._used_ids: set = set()
if data is not None: if data is not None:
@ -49,7 +52,7 @@ class Collection:
def sort(self, reverse: bool = False, **kwargs): def sort(self, reverse: bool = False, **kwargs):
self._data.sort(reverse=reverse, **kwargs) self._data.sort(reverse=reverse, **kwargs)
def map_element(self, element: DatabaseObject): def map_element(self, element: T):
for name, value in element.indexing_values: for name, value in element.indexing_values:
if value is None: if value is None:
continue continue
@ -58,7 +61,7 @@ class Collection:
self._used_ids.add(element.id) self._used_ids.add(element.id)
def unmap_element(self, element: DatabaseObject): def unmap_element(self, element: T):
for name, value in element.indexing_values: for name, value in element.indexing_values:
if value is None: if value is None:
continue continue
@ -70,7 +73,7 @@ class Collection:
except KeyError: except KeyError:
pass pass
def append(self, element: DatabaseObject, merge_on_conflict: bool = True, def append(self, element: T, merge_on_conflict: bool = True,
merge_into_existing: bool = True) -> AppendResult: merge_into_existing: bool = True) -> AppendResult:
""" """
:param element: :param element:
@ -117,7 +120,7 @@ class Collection:
return AppendResult(False, element, False) return AppendResult(False, element, False)
def extend(self, element_list: Iterable[DatabaseObject], merge_on_conflict: bool = True, def extend(self, element_list: Iterable[T], merge_on_conflict: bool = True,
merge_into_existing: bool = True): merge_into_existing: bool = True):
for element in element_list: for element in element_list:
self.append(element, merge_on_conflict=merge_on_conflict, merge_into_existing=merge_into_existing) self.append(element, merge_on_conflict=merge_on_conflict, merge_into_existing=merge_into_existing)
@ -138,7 +141,7 @@ class Collection:
return self._data[key] return self._data[key]
def __setitem__(self, key, value: DatabaseObject): def __setitem__(self, key, value: T):
if type(key) != int: if type(key) != int:
return ValueError("key needs to be an integer") return ValueError("key needs to be an integer")
@ -149,7 +152,7 @@ class Collection:
self._data[key] = value self._data[key] = value
@property @property
def shallow_list(self) -> List[DatabaseObject]: def shallow_list(self) -> List[T]:
""" """
returns a shallow copy of the data list returns a shallow copy of the data list
""" """

View File

@ -82,11 +82,11 @@ class Song(MainObject):
self.notes: FormattedText = notes or FormattedText() self.notes: FormattedText = notes or FormattedText()
self.source_collection: SourceCollection = SourceCollection(source_list) self.source_collection: SourceCollection = SourceCollection(source_list)
self.target_collection: Collection = Collection(data=target_list, element_type=Target) self.target_collection: Collection[Target] = Collection(data=target_list, element_type=Target)
self.lyrics_collection: Collection = Collection(data=lyrics_list, element_type=Lyrics) self.lyrics_collection: Collection[Lyrics] = Collection(data=lyrics_list, element_type=Lyrics)
self.album_collection: Collection = Collection(data=album_list, element_type=Album) self.album_collection: Collection[Album] = Collection(data=album_list, element_type=Album)
self.main_artist_collection = Collection(data=main_artist_list, element_type=Artist) self.main_artist_collection: Collection[Artist] = Collection(data=main_artist_list, element_type=Artist)
self.feature_artist_collection = Collection(data=feature_artist_list, element_type=Artist) self.feature_artist_collection: Collection[Artist] = Collection(data=feature_artist_list, element_type=Artist)
def _build_recursive_structures(self, build_version: int, merge: bool): def _build_recursive_structures(self, build_version: int, merge: bool):
if build_version == self.build_version: if build_version == self.build_version:
@ -255,9 +255,9 @@ class Album(MainObject):
self.notes = notes or FormattedText() self.notes = notes or FormattedText()
self.source_collection: SourceCollection = SourceCollection(source_list) self.source_collection: SourceCollection = SourceCollection(source_list)
self.song_collection: Collection = Collection(data=song_list, element_type=Song) self.song_collection: Collection[Song] = Collection(data=song_list, element_type=Song)
self.artist_collection: Collection = Collection(data=artist_list, element_type=Artist) self.artist_collection: Collection[Artist] = Collection(data=artist_list, element_type=Artist)
self.label_collection: Collection = Collection(data=label_list, element_type=Label) self.label_collection: Collection[Label] = Collection(data=label_list, element_type=Label)
def _build_recursive_structures(self, build_version: int, merge: bool): def _build_recursive_structures(self, build_version: int, merge: bool):
if build_version == self.build_version: if build_version == self.build_version:
@ -481,9 +481,9 @@ class Artist(MainObject):
self.general_genre = general_genre self.general_genre = general_genre
self.source_collection: SourceCollection = SourceCollection(source_list) self.source_collection: SourceCollection = SourceCollection(source_list)
self.feature_song_collection: Collection = Collection(data=feature_song_list, element_type=Song) self.feature_song_collection: Collection[Song] = Collection(data=feature_song_list, element_type=Song)
self.main_album_collection: Collection = Collection(data=main_album_list, element_type=Album) self.main_album_collection: Collection[Album] = Collection(data=main_album_list, element_type=Album)
self.label_collection: Collection = Collection(data=label_list, element_type=Label) self.label_collection: Collection[Label] = Collection(data=label_list, element_type=Label)
def compile(self, merge_into: bool = False): def compile(self, merge_into: bool = False):
""" """
@ -685,8 +685,8 @@ class Label(MainObject):
self.notes = notes or FormattedText() self.notes = notes or FormattedText()
self.source_collection: SourceCollection = SourceCollection(source_list) self.source_collection: SourceCollection = SourceCollection(source_list)
self.album_collection: Collection = Collection(data=album_list, element_type=Album) self.album_collection: Collection[Album] = Collection(data=album_list, element_type=Album)
self.current_artist_collection: Collection = Collection(data=current_artist_list, element_type=Artist) self.current_artist_collection: Collection[Artist] = Collection(data=current_artist_list, element_type=Artist)
def _build_recursive_structures(self, build_version: int, merge: False): def _build_recursive_structures(self, build_version: int, merge: False):
if build_version == self.build_version: if build_version == self.build_version: