Skip to content

wg_utilities.clients

Useful clients for commonly accessed APIs/services.

GoogleCalendarClient

Bases: GoogleClient[GoogleCalendarEntityJson]

Custom client specifically for Google's Calendar API.

Source code in wg_utilities/clients/google_calendar.py
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
class GoogleCalendarClient(GoogleClient[GoogleCalendarEntityJson]):
    """Custom client specifically for Google's Calendar API."""

    BASE_URL = "https://www.googleapis.com/calendar/v3"

    # Default query parameters for this client. The single-entity calls below
    # pass `maxResults=None` alongside their own request.
    # NOTE(review): presumably `None` drops the default in the base client - confirm.
    DEFAULT_PARAMS: ClassVar[
        dict[StrBytIntFlt, StrBytIntFlt | Iterable[StrBytIntFlt] | None]
    ] = {
        "maxResults": "250",
    }

    DEFAULT_SCOPES: ClassVar[list[str]] = [
        "https://www.googleapis.com/auth/calendar",
        "https://www.googleapis.com/auth/calendar.events",
    ]

    # Populated on first access of `primary_calendar`, then reused.
    _primary_calendar: Calendar

    @staticmethod
    def _build_event_time(
        value: datetime_ | date_,
        tz: str,
        param_name: str,
    ) -> dict[str, str]:
        """Build the `start`/`end` payload for an event.

        Args:
            value (Union[datetime, date]): the moment to serialise
            tz (str): the timezone name to send as `timeZone`
            param_name (str): the caller-facing parameter name, used in the error
                message

        Returns:
            dict: `timeZone` plus either `dateTime` (for a datetime) or `date`
                (for a date), both in ISO format

        Raises:
            TypeError: if `value` is neither a date nor a datetime
        """
        time_params = {
            "timeZone": tz,
        }

        # `datetime` is a subclass of `date`, so it has to be tested first.
        if isinstance(value, datetime_):
            time_params["dateTime"] = value.isoformat()
        elif isinstance(value, date_):
            time_params["date"] = value.isoformat()
        else:
            raise TypeError(f"`{param_name}` must be either a date or a datetime")

        return time_params

    def create_event(
        self,
        summary: str,
        start_datetime: datetime_ | date_,
        end_datetime: datetime_ | date_,
        tz: str | None = None,
        calendar: Calendar | None = None,
        extra_params: dict[str, str] | None = None,
    ) -> Event:
        """Create an event.

        Args:
            summary (str): the summary (title) of the event
            start_datetime (Union[datetime, date]): when the event starts
            end_datetime (Union[datetime, date]): when the event ends
            tz (str): the timezone which the event is in (IANA database name);
                falls back to `str(get_localzone())` when omitted
            calendar (Calendar): the calendar to add the event to; falls back to
                the primary calendar when omitted
            extra_params (dict): any extra params to pass in the request; these
                are merged last, so they override `summary`/`start`/`end`

        Returns:
            Event: a new event instance, fresh out of the oven

        Raises:
            TypeError: if the start/end datetime params are not the correct type
        """

        calendar = calendar or self.primary_calendar
        tz = tz or str(get_localzone())

        # Start is validated before end, so a bad start is reported first.
        start_params = self._build_event_time(start_datetime, tz, "start_datetime")
        end_params = self._build_event_time(end_datetime, tz, "end_datetime")

        event_json = self.post_json_response(
            f"/calendars/{calendar.id}/events",
            json={
                "summary": summary,
                "start": start_params,
                "end": end_params,
                **(extra_params or {}),
            },
            params={"maxResults": None},
        )

        return Event.from_json_response(event_json, calendar=calendar, google_client=self)

    def delete_event_by_id(self, event_id: str, calendar: Calendar | None = None) -> None:
        """Delete an event from a calendar.

        The request is sent directly with `delete` (using this client's
        `request_headers` and a 10 second timeout), and `raise_for_status()` is
        called on the response.

        Args:
            event_id (str): the ID of the event to delete
            calendar (Calendar): the calendar to delete the event from; falls
                back to the primary calendar when omitted
        """
        calendar = calendar or self.primary_calendar

        res = delete(
            f"{self.base_url}/calendars/{calendar.id}/events/{event_id}",
            headers=self.request_headers,
            timeout=10,
        )

        res.raise_for_status()

    def get_event_by_id(
        self,
        event_id: str,
        *,
        calendar: Calendar | None = None,
    ) -> Event:
        """Get a specific event by ID.

        Args:
            event_id (str): the ID of the event to retrieve
            calendar (Calendar): the calendar to look the event up in; falls back
                to the primary calendar when omitted

        Returns:
            Event: an Event instance built from the API response
        """
        calendar = calendar or self.primary_calendar

        return Event.from_json_response(
            self.get_json_response(
                f"/calendars/{calendar.id}/events/{event_id}",
                params={"maxResults": None},
            ),
            calendar=calendar,
            google_client=self,
        )

    @property
    def calendar_list(self) -> list[Calendar]:
        """List of calendars.

        The list is rebuilt from `/users/me/calendarList` on every access.

        Returns:
            list: a list of Calendar instances that the user has access to
        """
        return [
            Calendar.from_json_response(cal_json, google_client=self)
            for cal_json in self.get_items(
                "/users/me/calendarList",
                params={"maxResults": None},
            )
        ]

    @property
    def primary_calendar(self) -> Calendar:
        """Primary calendar for the user.

        Fetched from `/calendars/primary` on first access and cached on the
        instance afterwards.

        Returns:
            Calendar: the current user's primary calendar
        """
        if not hasattr(self, "_primary_calendar"):
            self._primary_calendar = Calendar.from_json_response(
                self.get_json_response("/calendars/primary", params={"maxResults": None}),
                google_client=self,
            )

        return self._primary_calendar

calendar_list: list[Calendar] property

List of calendars.

Returns:

Name Type Description
list list[Calendar]

a list of Calendar instances that the user has access to

primary_calendar: Calendar property

Primary calendar for the user.

Returns:

Name Type Description
Calendar Calendar

the current user's primary calendar

create_event(summary, start_datetime, end_datetime, tz=None, calendar=None, extra_params=None)

Create an event.

Parameters:

Name Type Description Default
summary str

the summary (title) of the event

required
start_datetime Union[datetime, date]

when the event starts

required
end_datetime Union[datetime, date]

when the event ends

required
tz str

the timezone which the event is in (IANA database name)

None
calendar Calendar

the calendar to add the event to

None
extra_params dict

any extra params to pass in the request

None

Returns:

Name Type Description
Event Event

a new event instance, fresh out of the oven

Raises:

Type Description
TypeError

if the start/end datetime params are not the correct type

Source code in wg_utilities/clients/google_calendar.py
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
def create_event(
    self,
    summary: str,
    start_datetime: datetime_ | date_,
    end_datetime: datetime_ | date_,
    tz: str | None = None,
    calendar: Calendar | None = None,
    extra_params: dict[str, str] | None = None,
) -> Event:
    """Create a new event on a calendar.

    Args:
        summary (str): the summary (title) of the event
        start_datetime (Union[datetime, date]): when the event starts
        end_datetime (Union[datetime, date]): when the event ends
        tz (str): the timezone which the event is in (IANA database name);
            `str(get_localzone())` is used when omitted
        calendar (Calendar): the calendar to add the event to; the primary
            calendar is used when omitted
        extra_params (dict): any extra params to pass in the request body; merged
            after `summary`/`start`/`end`

    Returns:
        Event: the event built from the API response

    Raises:
        TypeError: if the start/end datetime params are not the correct type
    """

    def _time_payload(moment, label, zone):
        # `datetime` subclasses `date`, so the datetime test must come first.
        if isinstance(moment, datetime_):
            key = "dateTime"
        elif isinstance(moment, date_):
            key = "date"
        else:
            raise TypeError(f"`{label}` must be either a date or a datetime")

        return {"timeZone": zone, key: moment.isoformat()}

    target_calendar = calendar or self.primary_calendar
    zone_name = tz or str(get_localzone())

    # Start is validated before end, so a bad start is reported first.
    start_payload = _time_payload(start_datetime, "start_datetime", zone_name)
    end_payload = _time_payload(end_datetime, "end_datetime", zone_name)

    request_body = {
        "summary": summary,
        "start": start_payload,
        "end": end_payload,
        **(extra_params or {}),
    }

    created_json = self.post_json_response(
        f"/calendars/{target_calendar.id}/events",
        json=request_body,
        params={"maxResults": None},
    )

    return Event.from_json_response(
        created_json,
        calendar=target_calendar,
        google_client=self,
    )

delete_event_by_id(event_id, calendar=None)

Delete an event from a calendar.

Parameters:

Name Type Description Default
event_id str

the ID of the event to delete

required
calendar Calendar

the calendar being updated

None
Source code in wg_utilities/clients/google_calendar.py
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
def delete_event_by_id(self, event_id: str, calendar: Calendar | None = None) -> None:
    """Remove a single event from a calendar.

    The request is issued directly through `delete` with this client's
    `request_headers` and a 10 second timeout; `raise_for_status()` is then
    called on the response.

    Args:
        event_id (str): the ID of the event to delete
        calendar (Calendar): the calendar holding the event; the primary calendar
            is used when omitted
    """
    target_calendar = calendar or self.primary_calendar
    event_url = f"{self.base_url}/calendars/{target_calendar.id}/events/{event_id}"

    response = delete(event_url, headers=self.request_headers, timeout=10)
    response.raise_for_status()

get_event_by_id(event_id, *, calendar=None)

Get a specific event by ID.

Parameters:

Name Type Description Default
event_id str

the ID of the event to retrieve

required
calendar Calendar

the calendar to look the event up in

None

Returns:

Name Type Description
Event Event

an Event instance with all relevant attributes

Source code in wg_utilities/clients/google_calendar.py
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
def get_event_by_id(
    self,
    event_id: str,
    *,
    calendar: Calendar | None = None,
) -> Event:
    """Fetch a single event by its ID.

    Args:
        event_id (str): the ID of the event to retrieve
        calendar (Calendar): the calendar to look the event up in; the primary
            calendar is used when omitted

    Returns:
        Event: an Event instance built from the API response
    """
    source_calendar = calendar or self.primary_calendar

    event_json = self.get_json_response(
        f"/calendars/{source_calendar.id}/events/{event_id}",
        params={"maxResults": None},
    )

    return Event.from_json_response(
        event_json,
        calendar=source_calendar,
        google_client=self,
    )

GoogleDriveClient

Bases: GoogleClient[JSONObj]

Custom client specifically for Google's Drive API.

Parameters:

Name Type Description Default
scopes list

a list of scopes the client can be given

None
creds_cache_path str

file path for where to cache credentials

None
Source code in wg_utilities/clients/google_drive.py
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
class GoogleDriveClient(GoogleClient[JSONObj]):
    """Custom client specifically for Google's Drive API.

    All constructor arguments except `item_metadata_retrieval` are forwarded to
    the base client's initialiser.

    Args:
        scopes (list): a list of scopes the client can be given; `DEFAULT_SCOPES`
            is used when omitted or empty
        creds_cache_path (Path): file path for where to cache credentials
        item_metadata_retrieval (IMR): stored on the instance as
            `item_metadata_retrieval`; defaults to `IMR.ON_FIRST_REQUEST`
    """

    BASE_URL = "https://www.googleapis.com/drive/v3"

    DEFAULT_SCOPES: ClassVar[list[str]] = [
        "https://www.googleapis.com/auth/drive",
        "https://www.googleapis.com/auth/drive.file",
        "https://www.googleapis.com/auth/drive.readonly",
        "https://www.googleapis.com/auth/drive.metadata.readonly",
        "https://www.googleapis.com/auth/drive.appdata",
        "https://www.googleapis.com/auth/drive.metadata",
        "https://www.googleapis.com/auth/drive.photos.readonly",
    ]

    # Populated on first access of `my_drive`, then reused.
    _my_drive: Drive

    def __init__(  # noqa: PLR0913
        self,
        *,
        client_id: str,
        client_secret: str,
        log_requests: bool = False,
        creds_cache_path: Path | None = None,
        creds_cache_dir: Path | None = None,
        scopes: list[str] | None = None,
        oauth_login_redirect_host: str = "localhost",
        oauth_redirect_uri_override: str | None = None,
        headless_auth_link_callback: Callable[[str], None] | None = None,
        use_existing_credentials_only: bool = False,
        validate_request_success: bool = True,
        item_metadata_retrieval: IMR = IMR.ON_FIRST_REQUEST,
    ):
        # Everything bar `item_metadata_retrieval` belongs to the base client;
        # the Drive-specific scopes and base URL are filled in here.
        super().__init__(
            client_id=client_id,
            client_secret=client_secret,
            log_requests=log_requests,
            creds_cache_path=creds_cache_path,
            creds_cache_dir=creds_cache_dir,
            scopes=scopes or self.DEFAULT_SCOPES,
            oauth_login_redirect_host=oauth_login_redirect_host,
            oauth_redirect_uri_override=oauth_redirect_uri_override,
            headless_auth_link_callback=headless_auth_link_callback,
            use_existing_credentials_only=use_existing_credentials_only,
            base_url=self.BASE_URL,
            validate_request_success=validate_request_success,
        )

        self.item_metadata_retrieval = item_metadata_retrieval

    @property
    def my_drive(self) -> Drive:
        """User's personal Drive.

        Built from `/files/root` on first access and cached on the instance.

        Returns:
            Drive: the user's root directory/main Drive
        """
        if hasattr(self, "_my_drive"):
            return self._my_drive

        root_json = self.get_json_response(
            "/files/root",
            params={"fields": "*", "pageSize": None},
        )
        self._my_drive = Drive.from_json_response(root_json, google_client=self)

        return self._my_drive

    @property
    def shared_drives(self) -> list[Drive]:
        """Get a list of all shared drives.

        The list is rebuilt from `/drives` on every access.

        Returns:
            list: a list of Shared Drives the current user has access to
        """
        drive_items = self.get_items(
            "/drives",
            list_key="drives",
            params={"fields": "*"},
        )

        return [
            Drive.from_json_response(drive_json, google_client=self)
            for drive_json in drive_items
        ]

my_drive: Drive property

User's personal Drive.

Returns:

Name Type Description
Drive Drive

the user's root directory/main Drive

shared_drives: list[Drive] property

Get a list of all shared drives.

Returns:

Name Type Description
list list[Drive]

a list of Shared Drives the current user has access to

GoogleFitClient

Bases: GoogleClient[Any]

Custom client for interacting with the Google Fit API.

See Also

GoogleClient: the base Google client, used for authentication and common functions

Source code in wg_utilities/clients/google_fit.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
class GoogleFitClient(GoogleClient[Any]):
    """Custom client for interacting with the Google Fit API.

    See Also:
        GoogleClient: the base Google client, used for authentication and common functions
    """

    BASE_URL = "https://www.googleapis.com/fitness/v1"

    DEFAULT_SCOPES: ClassVar[list[str]] = [
        "https://www.googleapis.com/auth/fitness.activity.read",
        "https://www.googleapis.com/auth/fitness.body.read",
        "https://www.googleapis.com/auth/fitness.location.read",
        "https://www.googleapis.com/auth/fitness.nutrition.read",
    ]

    # Per-instance cache of DataSource objects, created lazily by `data_sources`.
    _data_sources: dict[str, DataSource]

    def get_data_source(self, data_source_id: str) -> DataSource:
        """Return the data source with the given UID, creating it if needed.

        DataSource instances are cached for the lifetime of the GoogleClient instance

        Args:
            data_source_id (str): the UID of the data source

        Returns:
            DataSource: the cached instance, or a newly created (and cached) one
        """
        cache = self.data_sources

        cached = cache.get(data_source_id)
        if cached is not None:
            return cached

        created = DataSource(data_source_id=data_source_id, google_client=self)
        cache[data_source_id] = created

        return created

    @property
    def data_sources(self) -> dict[str, DataSource]:
        """Data sources available to this client.

        The backing dict starts empty and is created on first access.

        Returns:
            dict: a dict of data sources, keyed by their UID
        """
        if hasattr(self, "_data_sources"):
            return self._data_sources

        self._data_sources = {}

        return self._data_sources

data_sources: dict[str, DataSource] property

Data sources available to this client.

Returns:

Name Type Description
dict dict[str, DataSource]

a dict of data sources, keyed by their UID

get_data_source(data_source_id)

Get a data source based on its UID.

DataSource instances are cached for the lifetime of the GoogleClient instance

Parameters:

Name Type Description Default
data_source_id str

the UID of the data source

required

Returns:

Name Type Description
DataSource DataSource

an instance, ready to use!

Source code in wg_utilities/clients/google_fit.py
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
def get_data_source(self, data_source_id: str) -> DataSource:
    """Return the data source with the given UID, creating it if needed.

    DataSource instances are cached for the lifetime of the GoogleClient instance

    Args:
        data_source_id (str): the UID of the data source

    Returns:
        DataSource: the cached instance, or a newly created (and cached) one
    """
    cache = self.data_sources

    cached = cache.get(data_source_id)
    if cached is not None:
        return cached

    created = DataSource(data_source_id=data_source_id, google_client=self)
    cache[data_source_id] = created

    return created

GooglePhotosClient

Bases: GoogleClient[GooglePhotosEntityJson]

Custom client for interacting with the Google Photos API.

See Also

GoogleClient: the base Google client, used for authentication and common functions

Source code in wg_utilities/clients/google_photos.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
class GooglePhotosClient(GoogleClient[GooglePhotosEntityJson]):
    """Custom client for interacting with the Google Photos API.

    See Also:
        GoogleClient: the base Google client, used for authentication and common functions
    """

    BASE_URL = "https://photoslibrary.googleapis.com/v1"

    DEFAULT_SCOPES: ClassVar[list[str]] = [
        "https://www.googleapis.com/auth/photoslibrary.readonly",
        "https://www.googleapis.com/auth/photoslibrary.appendonly",
        "https://www.googleapis.com/auth/photoslibrary.readonly.appcreateddata",
        "https://www.googleapis.com/auth/photoslibrary.edit.appcreateddata",
    ]

    # Albums seen so far. May hold only individually fetched albums until the
    # `albums` property has run its full listing.
    _albums: list[Album]
    # Only really used to check if all album metadata has been fetched, not
    # available to the user (would still require caching all albums).
    _album_count: int

    def get_album_by_id(self, album_id: str) -> Album:
        """Get an album by its ID.

        Albums already held in the local cache are returned without a request;
        otherwise the album is fetched from `/albums/{album_id}` and cached.

        Args:
            album_id (str): the ID of the album to fetch

        Returns:
            Album: the album with the given ID
        """

        if hasattr(self, "_albums"):
            for album in self._albums:
                if album.id == album_id:
                    return album

        album = Album.from_json_response(
            self.get_json_response(f"/albums/{album_id}", params={"pageSize": None}),
            google_client=self,
        )

        # `_album_count` is deliberately left untouched here, so a later access
        # of `albums` still performs the full listing.
        if not hasattr(self, "_albums"):
            self._albums = [album]
        else:
            self._albums.append(album)

        return album

    def get_album_by_name(self, album_name: str) -> Album:
        """Get an album definition from the Google API based on the album name.

        The first album whose title equals `album_name` is returned.

        Args:
            album_name (str): the name of the album to find

        Returns:
            Album: an Album instance, with all metadata etc.

        Raises:
            FileNotFoundError: if the client can't find an album with the correct name
        """

        LOGGER.info("Getting metadata for album `%s`", album_name)
        for album in self.albums:
            if album.title == album_name:
                return album

        raise FileNotFoundError(f"Unable to find album with name {album_name!r}.")

    @property
    def albums(self) -> list[Album]:
        """List all albums in the active Google account.

        The full listing is requested once per instance. If some albums were
        already cached individually (via `get_album_by_id`), only the albums
        not yet cached are added.

        Returns:
            list: a list of Album instances
        """

        # NOTE(review): this passes an absolute URL built from `BASE_URL`, while
        # `get_album_by_id` passes a relative path - confirm both are intended.
        if not hasattr(self, "_albums"):
            self._albums = [
                Album.from_json_response(item, google_client=self)
                for item in self.get_items(
                    f"{self.BASE_URL}/albums",
                    list_key="albums",
                    params={"pageSize": 50},
                )
            ]
            self._album_count = len(self._albums)
        elif not hasattr(self, "_album_count"):
            # A set keeps the per-item "already cached?" test constant-time.
            known_album_ids = {album.id for album in self._albums}
            self._albums.extend(
                [
                    Album.from_json_response(item, google_client=self)
                    for item in self.get_items(
                        f"{self.BASE_URL}/albums",
                        list_key="albums",
                        params={"pageSize": 50},
                    )
                    if item["id"] not in known_album_ids
                ],
            )

            self._album_count = len(self._albums)

        return self._albums

albums: list[Album] property

List all albums in the active Google account.

Returns:

Name Type Description
list list[Album]

a list of Album instances

get_album_by_id(album_id)

Get an album by its ID.

Parameters:

Name Type Description Default
album_id str

the ID of the album to fetch

required

Returns:

Name Type Description
Album Album

the album with the given ID

Source code in wg_utilities/clients/google_photos.py
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
def get_album_by_id(self, album_id: str) -> Album:
    """Look an album up by its ID, using the local cache where possible.

    Albums already held in `_albums` are returned without a request; otherwise
    the album is fetched from `/albums/{album_id}` and added to the cache.

    Args:
        album_id (str): the ID of the album to fetch

    Returns:
        Album: the album with the given ID
    """

    if hasattr(self, "_albums"):
        cached = next(
            (known for known in self._albums if known.id == album_id),
            None,
        )
        if cached is not None:
            return cached

    album_json = self.get_json_response(
        f"/albums/{album_id}",
        params={"pageSize": None},
    )
    fetched = Album.from_json_response(album_json, google_client=self)

    # Re-checked after the fetch, exactly as the cache is created on demand.
    if hasattr(self, "_albums"):
        self._albums.append(fetched)
    else:
        self._albums = [fetched]

    return fetched

get_album_by_name(album_name)

Get an album definition from the Google API based on the album name.

Parameters:

Name Type Description Default
album_name str

the name of the album to find

required

Returns:

Name Type Description
Album Album

an Album instance, with all metadata etc.

Raises:

Type Description
FileNotFoundError

if the client can't find an album with the correct name

Source code in wg_utilities/clients/google_photos.py
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
def get_album_by_name(self, album_name: str) -> Album:
    """Find an album by its title.

    The first album in `self.albums` whose title equals `album_name` is
    returned.

    Args:
        album_name (str): the name of the album to find

    Returns:
        Album: the matching Album instance

    Raises:
        FileNotFoundError: if the client can't find an album with the correct name
    """

    LOGGER.info("Getting metadata for album `%s`", album_name)

    found = next(
        (candidate for candidate in self.albums if candidate.title == album_name),
        None,
    )
    if found is None:
        raise FileNotFoundError(f"Unable to find album with name {album_name!r}.")

    return found

ItemMetadataRetrieval

Bases: StrEnum

The type of metadata retrieval to use for items.

Attributes:

Name Type Description
ON_DEMAND str

only retrieves single metadata items on demand. Best for reducing memory usage, but makes the most HTTP requests.

ON_FIRST_REQUEST str

retrieves all metadata items on the first request for any metadata value. Nice middle ground between memory usage and HTTP requests.

ON_INIT str

retrieves metadata on instance initialisation. Increases memory usage, makes the fewest HTTP requests. If combined with a Drive.map call, it can be used to preload all metadata for the entire Drive.

Source code in wg_utilities/clients/google_drive.py
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
class ItemMetadataRetrieval(StrEnum):
    """The type of metadata retrieval to use for items.

    Attributes:
        ON_DEMAND (str): only retrieves single metadata items on demand. Best for
            reducing memory usage, but makes the most HTTP requests.
        ON_FIRST_REQUEST (str): retrieves all metadata items on the first request for
            _any_ metadata value. A middle ground between memory usage and HTTP
            requests.
        ON_INIT (str): retrieves metadata on instance initialisation. Increases memory
            usage, but makes the fewest HTTP requests. If combined with a `Drive.map`
            call, it can be used to preload all metadata for the entire Drive.
    """

    ON_DEMAND = "on_demand"
    ON_FIRST_REQUEST = "on_first_request"
    ON_INIT = "on_init"

JsonApiClient

Bases: Generic[GetJsonResponse]

Generic no-auth JSON API client to simplify interactions.

Sort of an SDK?

Source code in wg_utilities/clients/json_api_client.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
class JsonApiClient(Generic[GetJsonResponse]):
    """Generic no-auth JSON API client to simplify interactions.

    Sort of an SDK?

    The `GetJsonResponse` type parameter is the type returned by the JSON helper
    methods (`get_json_response` and `post_json_response`).
    """

    # Fallback root URL, used when no `base_url` is passed to `__init__`.
    BASE_URL: str

    # Query parameters merged into the `params` of every request; a key passed
    # explicitly by the caller takes precedence over the default with that key.
    DEFAULT_PARAMS: ClassVar[
        dict[StrBytIntFlt, StrBytIntFlt | Iterable[StrBytIntFlt] | None]
    ] = {}

    def __init__(
        self,
        *,
        log_requests: bool = False,
        base_url: str | None = None,
        validate_request_success: bool = True,
    ):
        """Initialise the client.

        Args:
            log_requests (bool): whether each outgoing request is logged at DEBUG
                level
            base_url (str | None): root URL prepended to paths starting with "/";
                the class-level `BASE_URL` is used when this is omitted or empty
            validate_request_success (bool): whether `raise_for_status` is called
                on every response
        """
        # An instance-level override wins over the class-level constant.
        self.base_url = base_url if base_url else self.BASE_URL
        self.log_requests = log_requests
        self.validate_request_success = validate_request_success

    def _get(
        self,
        url: str,
        *,
        params: (
            dict[
                StrBytIntFlt,
                StrBytIntFlt | Iterable[StrBytIntFlt] | None,
            ]
            | None
        ) = None,
        header_overrides: Mapping[str, str | bytes] | None = None,
        timeout: float | tuple[float, float] | tuple[float, None] | None = None,
        json: Any | None = None,
        data: Any | None = None,
    ) -> Response:
        """Send a GET request and return the raw response.

        Thin wrapper: every argument is forwarded unchanged to `_request`, with
        `get` as the HTTP method.

        Args:
            url (str): the endpoint to call; either a full URL or a path starting
                with "/" (which is resolved against the base URL)
            params (dict): the query parameters for the request
            header_overrides (dict): headers to send instead of the default
                request headers
            timeout (float | tuple[float, float] | tuple[float, None] | None): the
                timeout for the request
            json (Any): a JSON payload for the request
            data (Any): a data payload for the request

        Returns:
            Response: the response from the HTTP request
        """
        return self._request(
            url=url,
            method=get,
            params=params,
            json=json,
            data=data,
            header_overrides=header_overrides,
            timeout=timeout,
        )

    def _post(
        self,
        url: str,
        *,
        params: (
            dict[
                StrBytIntFlt,
                StrBytIntFlt | Iterable[StrBytIntFlt] | None,
            ]
            | None
        ) = None,
        header_overrides: Mapping[str, str | bytes] | None = None,
        timeout: float | tuple[float, float] | tuple[float, None] | None = None,
        json: Any | None = None,
        data: Any | None = None,
    ) -> Response:
        """Send a POST request and return the raw response.

        Thin wrapper: every argument is forwarded unchanged to `_request`, with
        `post` as the HTTP method.

        Args:
            url (str): the endpoint to call; either a full URL or a path starting
                with "/" (which is resolved against the base URL)
            params (dict): the query parameters for the request
            header_overrides (dict): headers to send instead of the default
                request headers
            timeout (float | tuple[float, float] | tuple[float, None] | None): the
                timeout for the request
            json (Any): a JSON payload for the request
            data (Any): a data payload for the request

        Returns:
            Response: the response from the HTTP request
        """
        return self._request(
            url=url,
            method=post,
            params=params,
            json=json,
            data=data,
            header_overrides=header_overrides,
            timeout=timeout,
        )

    def _request(
        self,
        *,
        method: Callable[..., Response],
        url: str,
        params: (
            dict[
                StrBytIntFlt,
                StrBytIntFlt | Iterable[StrBytIntFlt] | None,
            ]
            | None
        ) = None,
        header_overrides: Mapping[str, str | bytes] | None = None,
        timeout: float | tuple[float, float] | tuple[float, None] | None = None,
        json: Any | None = None,
        data: Any | None = None,
    ) -> Response:
        """Make a HTTP request.

        The query parameters sent are the caller's `params` plus any entry of
        `DEFAULT_PARAMS` whose key the caller did not supply; entries whose value
        is `None` are then dropped. The caller's `params` dict is left untouched.

        Args:
            method (Callable): the HTTP method to use
            url (str): the URL path to the endpoint (not necessarily including the
                base URL); a value starting with "/" is prefixed with
                `self.base_url`
            params (dict): the parameters to be passed in the HTTP request
            header_overrides (dict): headers to send instead of
                `self.request_headers`; `None` means "use `self.request_headers`"
            timeout (float | tuple[float, float] | tuple[float, None] | None): the
                timeout for the request
            json (dict): the JSON payload to be passed in the HTTP request
            data (dict): the data payload to be passed in the HTTP request

        Returns:
            Response: the response from the HTTP request

        Raises:
            Exception: whatever `raise_for_status()` raises for an unsuccessful
                response, when `self.validate_request_success` is set
        """
        # Work on a copy so the caller's `params` dict is never mutated.
        merged_params = dict(params) if params is not None else {}

        for key, default in self.DEFAULT_PARAMS.items():
            # An explicitly passed key - even one set to `None` - beats the default.
            if key not in merged_params:
                merged_params[key] = deepcopy(default)

        # A `None` value means "do not send this parameter".
        query_params = {k: v for k, v in merged_params.items() if v is not None}

        if url.startswith("/"):
            url = f"{self.base_url}{url}"

        if self.log_requests:
            LOGGER.debug(
                "%s %s: %s",
                method.__name__.upper(),
                url,
                dumps(query_params, default=str),
            )

        res = method(
            url,
            headers=(
                header_overrides if header_overrides is not None else self.request_headers
            ),
            params=query_params,
            timeout=timeout,
            json=json,
            data=data,
        )

        if self.validate_request_success:
            res.raise_for_status()

        return res

    def _request_json_response(
        self,
        *,
        method: Callable[..., Response],
        url: str,
        params: (
            dict[
                StrBytIntFlt,
                StrBytIntFlt | Iterable[StrBytIntFlt] | None,
            ]
            | None
        ) = None,
        header_overrides: Mapping[str, str | bytes] | None = None,
        timeout: float | tuple[float, float] | tuple[float, None] | None = None,
        json: Any | None = None,
        data: Any | None = None,
    ) -> GetJsonResponse:
        """Make a request via `_request` and return the decoded JSON body.

        Args:
            method (Callable): the HTTP method to use
            url (str): the URL path to the endpoint (not necessarily including the
                base URL)
            params (dict): the parameters to be passed in the HTTP request
            header_overrides (dict): headers to send instead of the default
                request headers
            timeout (float | tuple[float, float] | tuple[float, None] | None): the
                timeout for the request
            json (Any): a JSON payload to pass in the request
            data (Any): a data payload to pass in the request

        Returns:
            dict: the JSON from the response; an empty dict for a 204 response or
                for a response with an empty body

        Raises:
            ValueError: if the body is non-empty but cannot be decoded as JSON; the
                message is the response text
        """
        response = self._request(
            url=url,
            method=method,
            params=params,
            json=json,
            data=data,
            header_overrides=header_overrides,
            timeout=timeout,
        )

        # 204 No Content: there is nothing to parse.
        if response.status_code == HTTPStatus.NO_CONTENT:
            return {}  # type: ignore[return-value]

        try:
            payload = response.json()
        except JSONDecodeError as exc:
            # Only a non-empty body that fails to decode is treated as an error.
            if response.content:
                raise ValueError(response.text) from exc

            return {}  # type: ignore[return-value]

        return payload  # type: ignore[no-any-return]

    def get_json_response(
        self,
        url: str,
        /,
        *,
        params: (
            dict[
                StrBytIntFlt,
                StrBytIntFlt | Iterable[StrBytIntFlt] | None,
            ]
            | None
        ) = None,
        header_overrides: Mapping[str, str | bytes] | None = None,
        timeout: float | tuple[float, float] | tuple[float, None] | None = None,
        json: Any | None = None,
        data: Any | None = None,
    ) -> GetJsonResponse:
        """Get a simple JSON object from a URL.

        Args:
            url (str): the API endpoint to GET
            params (dict): the parameters to be passed in the HTTP request
            header_overrides (dict): headers to send instead of
                `self.request_headers` (a full replacement, not a merge). Setting
                this to an empty dict will erase all headers; `None` will use
                `self.request_headers`.
            timeout (float | tuple[float, float] | tuple[float, None] | None): the
                timeout for the request, in the same forms accepted by
                `post_json_response`
            json (dict): a JSON payload to pass in the request
            data (dict): a data payload to pass in the request

        Returns:
            dict: the JSON from the response
        """

        return self._request_json_response(
            method=get,
            url=url,
            params=params,
            header_overrides=header_overrides,
            timeout=timeout,
            json=json,
            data=data,
        )

    def post_json_response(
        self,
        url: str,
        /,
        *,
        params: (
            dict[
                StrBytIntFlt,
                StrBytIntFlt | Iterable[StrBytIntFlt] | None,
            ]
            | None
        ) = None,
        header_overrides: Mapping[str, str | bytes] | None = None,
        timeout: float | tuple[float, float] | tuple[float, None] | None = None,
        json: Any | None = None,
        data: Any | None = None,
    ) -> GetJsonResponse:
        """Send a POST request and return the JSON object from the response.

        Args:
            url (str): the API endpoint to POST to
            params (dict): the parameters to be passed in the HTTP request
            header_overrides (dict): headers to send instead of
                `self.request_headers`. An empty dict erases all headers; `None`
                uses `self.request_headers`.
            timeout (float | tuple[float, float] | tuple[float, None] | None): the
                timeout for the request
            json (dict): a JSON payload to pass in the request
            data (dict): a data payload to pass in the request

        Returns:
            dict: the JSON from the response
        """
        return self._request_json_response(
            url=url,
            method=post,
            params=params,
            json=json,
            data=data,
            header_overrides=header_overrides,
            timeout=timeout,
        )

    @property
    def request_headers(self) -> dict[str, str]:
        """Default headers sent with requests to the API.

        Returns:
            dict: a single `Content-Type` header declaring a JSON body
        """
        default_headers = {"Content-Type": "application/json"}
        return default_headers

request_headers: dict[str, str] property

Header to be used in requests to the API.

Returns:

Name Type Description
dict dict[str, str]

default headers for HTTP requests

get_json_response(url, /, *, params=None, header_overrides=None, timeout=None, json=None, data=None)

Get a simple JSON object from a URL.

Parameters:

Name Type Description Default
url str

the API endpoint to GET

required
params dict

the parameters to be passed in the HTTP request

None
header_overrides dict

headers to send instead of self.request_headers (a full replacement, not a merge). Setting this to an empty dict will erase all headers; None will use self.request_headers.

None
timeout float

How many seconds to wait for the server to send data before giving up

None
json dict

a JSON payload to pass in the request

None
data dict

a data payload to pass in the request

None

Returns:

Name Type Description
dict GetJsonResponse

the JSON from the response

Source code in wg_utilities/clients/json_api_client.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
def get_json_response(
    self,
    url: str,
    /,
    *,
    params: (
        dict[
            StrBytIntFlt,
            StrBytIntFlt | Iterable[StrBytIntFlt] | None,
        ]
        | None
    ) = None,
    header_overrides: Mapping[str, str | bytes] | None = None,
    timeout: float | tuple[float, float] | tuple[float, None] | None = None,
    json: Any | None = None,
    data: Any | None = None,
) -> GetJsonResponse:
    """Get a simple JSON object from a URL.

    Args:
        url (str): the API endpoint to GET
        params (dict): the parameters to be passed in the HTTP request
        header_overrides (dict): headers to send instead of
            `self.request_headers` (a full replacement, not a merge). Setting
            this to an empty dict will erase all headers; `None` will use
            `self.request_headers`.
        timeout (float | tuple[float, float] | tuple[float, None] | None): the
            timeout for the request, in the same forms accepted by
            `post_json_response`
        json (dict): a JSON payload to pass in the request
        data (dict): a data payload to pass in the request

    Returns:
        dict: the JSON from the response
    """

    return self._request_json_response(
        method=get,
        url=url,
        params=params,
        header_overrides=header_overrides,
        timeout=timeout,
        json=json,
        data=data,
    )

post_json_response(url, /, *, params=None, header_overrides=None, timeout=None, json=None, data=None)

Get a simple JSON object from a URL via a POST request.

Parameters:

Name Type Description Default
url str

the API endpoint to POST to

required
params dict

the parameters to be passed in the HTTP request

None
header_overrides dict

headers to send instead of self.request_headers (a full replacement, not a merge). Setting this to an empty dict will erase all headers; None will use self.request_headers.

None
timeout float

How many seconds to wait for the server to send data before giving up

None
json dict

a JSON payload to pass in the request

None
data dict

a data payload to pass in the request

None

Returns:

Name Type Description
dict GetJsonResponse

the JSON from the response

Source code in wg_utilities/clients/json_api_client.py
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
def post_json_response(
    self,
    url: str,
    /,
    *,
    params: (
        dict[
            StrBytIntFlt,
            StrBytIntFlt | Iterable[StrBytIntFlt] | None,
        ]
        | None
    ) = None,
    header_overrides: Mapping[str, str | bytes] | None = None,
    timeout: float | tuple[float, float] | tuple[float, None] | None = None,
    json: Any | None = None,
    data: Any | None = None,
) -> GetJsonResponse:
    """Send a POST request and return the JSON object from the response.

    Args:
        url (str): the API endpoint to POST to
        params (dict): the parameters to be passed in the HTTP request
        header_overrides (dict): headers to send instead of
            `self.request_headers`. An empty dict erases all headers; `None`
            uses `self.request_headers`.
        timeout (float | tuple[float, float] | tuple[float, None] | None): the
            timeout for the request
        json (dict): a JSON payload to pass in the request
        data (dict): a data payload to pass in the request

    Returns:
        dict: the JSON from the response
    """
    return self._request_json_response(
        url=url,
        method=post,
        params=params,
        json=json,
        data=data,
        header_overrides=header_overrides,
        timeout=timeout,
    )

MonzoClient

Bases: OAuthClient[MonzoGJR]

Custom client for interacting with Monzo's API.

Source code in wg_utilities/clients/monzo.py
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
class MonzoClient(OAuthClient[MonzoGJR]):
    """Custom client for interacting with Monzo's API."""

    # OAuth endpoints; presumably consumed by `OAuthClient` when authenticating
    # (TODO confirm against the base class).
    ACCESS_TOKEN_ENDPOINT = "https://api.monzo.com/oauth2/token"  # noqa: S105
    AUTH_LINK_BASE = "https://auth.monzo.com"
    # Root URL of the Monzo API.
    BASE_URL = "https://api.monzo.com"

    # Monzo requests carry no default query parameters.
    DEFAULT_PARAMS: ClassVar[
        dict[StrBytIntFlt, StrBytIntFlt | Iterable[StrBytIntFlt] | None]
    ] = {}

    # Cache behind the `current_account` property; unset until first access.
    _current_account: Account

    def deposit_into_pot(
        self,
        pot: Pot,
        amount_pence: int,
        dedupe_id: str | None = None,
    ) -> None:
        """Move money from the user's account into one of their pots.

        The request is sent to the instance's `base_url`, so a client created with
        a custom base URL deposits against that host, like every other call.

        Args:
            pot (Pot): the target pot
            amount_pence (int): the amount of money to deposit, in pence
            dedupe_id (str): unique string used to de-duplicate deposits. When
                omitted (or empty) one is built from the pot ID, the amount and
                the current UTC time

        Raises:
            Exception: whatever `raise_for_status()` raises if the API responds
                with an unsuccessful status
        """

        dedupe_id = dedupe_id or "|".join(
            [pot.id, str(amount_pence), str(utcnow(DTU.SECOND))],
        )

        # Use the instance-level base URL (which defaults to `BASE_URL`) rather
        # than the class constant, so that a `base_url` override is respected.
        res = put(
            f"{self.base_url}/pots/{pot.id}/deposit",
            headers=self.request_headers,
            data={
                "source_account_id": self.current_account.id,
                "amount": amount_pence,
                "dedupe_id": dedupe_id,
            },
            timeout=10,
        )
        res.raise_for_status()

    def list_accounts(
        self,
        *,
        include_closed: bool = False,
        account_type: str | None = None,
    ) -> list[Account]:
        """Fetch the user's accounts from the `/accounts` endpoint.

        Args:
            include_closed (bool): whether closed accounts are kept in the result
            account_type (str): optional account type filter, sent as the
                `account_type` query parameter

        Returns:
            list: Account instances, one per account kept
        """
        query = {"account_type": account_type} if account_type else None
        payload = self.get_json_response("/accounts", params=query)

        accounts = []
        for raw_account in payload.get("accounts", []):
            # An entry without a "closed" flag is treated as closed.
            if raw_account.get("closed", True) and not include_closed:
                continue

            accounts.append(Account.from_json_response(raw_account, self))

        return accounts

    def list_pots(self, *, include_deleted: bool = False) -> list[Pot]:
        """Fetch the pots attached to the user's current account.

        Args:
            include_deleted (bool): whether deleted pots are kept in the result

        Returns:
            list: Pot instances, one per pot kept
        """
        payload = self.get_json_response(
            "/pots",
            params={"current_account_id": self.current_account.id},
        )

        pots = []
        for raw_pot in payload.get("pots", []):
            # An entry without a "deleted" flag is treated as deleted.
            if raw_pot.get("deleted", True) and not include_deleted:
                continue

            pots.append(Pot(**raw_pot))

        return pots

    def get_pot_by_id(self, pot_id: str) -> Pot | None:
        """Look up a pot by its ID, searching deleted pots too.

        Args:
            pot_id (str): the ID of the pot to find

        Returns:
            Pot: the first pot whose ID matches, or None if there is no match
        """
        matching_pots = (
            candidate
            for candidate in self.list_pots(include_deleted=True)
            if candidate.id == pot_id
        )
        return next(matching_pots, None)

    def get_pot_by_name(
        self,
        pot_name: str,
        *,
        exact_match: bool = False,
        include_deleted: bool = False,
    ) -> Pot | None:
        """Look up a pot by its name.

        Names are compared case-insensitively in both modes.

        Args:
            pot_name (str): the name of the pot to find
            exact_match (bool): if False, both the requested name and each pot's
                name are passed through `cleanse_string` before comparison
            include_deleted (bool): whether deleted pots are searched too

        Returns:
            Pot: the first pot whose name matches, or None if there is no match
        """

        def normalise(name: str) -> str:
            # Exact matching skips cleansing; lower-casing happens at comparison.
            return name if exact_match else cleanse_string(name)

        wanted_name = normalise(pot_name)

        return next(
            (
                candidate
                for candidate in self.list_pots(include_deleted=include_deleted)
                if normalise(candidate.name).lower() == wanted_name.lower()
            ),
            None,
        )

    @property
    def current_account(self) -> Account:
        """Return the Monzo user's main account, fetching it on first access.

        The first account of type "uk_retail" is used and then cached on the
        instance; we assume there'll only be one main account per user.

        Returns:
            Account: the user's main account, instantiated
        """
        if hasattr(self, "_current_account"):
            return self._current_account

        retail_accounts = self.list_accounts(account_type="uk_retail")
        self._current_account = retail_accounts[0]

        return self._current_account

    @property
    def request_headers(self) -> dict[str, str]:
        """Headers sent with requests to the Monzo API.

        Returns:
            dict: an `Authorization` header carrying the access token as a bearer
                token
        """
        bearer_token = f"Bearer {self.access_token}"
        return {"Authorization": bearer_token}

current_account: Account property

Get the main account for the Monzo user.

We assume there'll only be one main account per user.

Returns:

Name Type Description
Account Account

the user's main account, instantiated

request_headers: dict[str, str] property

Header to be used in requests to the API.

Returns:

Name Type Description
dict dict[str, str]

auth headers for HTTP requests

deposit_into_pot(pot, amount_pence, dedupe_id=None)

Move money from the user's account into one of their pots.

Parameters:

Name Type Description Default
pot Pot

the target pot

required
amount_pence int

the amount of money to deposit, in pence

required
dedupe_id str

unique string used to de-duplicate deposits. Will be created if not provided

None
Source code in wg_utilities/clients/monzo.py
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
def deposit_into_pot(
    self,
    pot: Pot,
    amount_pence: int,
    dedupe_id: str | None = None,
) -> None:
    """Move money from the user's account into one of their pots.

    The request is sent to the instance's `base_url`, so a client created with
    a custom base URL deposits against that host, like every other call.

    Args:
        pot (Pot): the target pot
        amount_pence (int): the amount of money to deposit, in pence
        dedupe_id (str): unique string used to de-duplicate deposits. When
            omitted (or empty) one is built from the pot ID, the amount and
            the current UTC time

    Raises:
        Exception: whatever `raise_for_status()` raises if the API responds
            with an unsuccessful status
    """

    dedupe_id = dedupe_id or "|".join(
        [pot.id, str(amount_pence), str(utcnow(DTU.SECOND))],
    )

    # Use the instance-level base URL (which defaults to `BASE_URL`) rather
    # than the class constant, so that a `base_url` override is respected.
    res = put(
        f"{self.base_url}/pots/{pot.id}/deposit",
        headers=self.request_headers,
        data={
            "source_account_id": self.current_account.id,
            "amount": amount_pence,
            "dedupe_id": dedupe_id,
        },
        timeout=10,
    )
    res.raise_for_status()

get_pot_by_id(pot_id)

Get a pot from its ID.

Parameters:

Name Type Description Default
pot_id str

the ID of the pot to find

required

Returns:

Name Type Description
Pot Pot | None

the Pot instance

Source code in wg_utilities/clients/monzo.py
531
532
533
534
535
536
537
538
539
540
541
542
543
544
def get_pot_by_id(self, pot_id: str) -> Pot | None:
    """Look up a pot by its ID, searching deleted pots too.

    Args:
        pot_id (str): the ID of the pot to find

    Returns:
        Pot: the first pot whose ID matches, or None if there is no match
    """
    matching_pots = (
        candidate
        for candidate in self.list_pots(include_deleted=True)
        if candidate.id == pot_id
    )
    return next(matching_pots, None)

get_pot_by_name(pot_name, *, exact_match=False, include_deleted=False)

Get a pot from its name.

Parameters:

Name Type Description Default
pot_name str

the name of the pot to find

required
exact_match bool

if False, all pot names will be cleansed before evaluation

False
include_deleted bool

whether to include deleted pots in the response

False

Returns:

Name Type Description
Pot Pot | None

the Pot instance

Source code in wg_utilities/clients/monzo.py
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
def get_pot_by_name(
    self,
    pot_name: str,
    *,
    exact_match: bool = False,
    include_deleted: bool = False,
) -> Pot | None:
    """Look up a pot by its name.

    Names are compared case-insensitively in both modes.

    Args:
        pot_name (str): the name of the pot to find
        exact_match (bool): if False, both the requested name and each pot's
            name are passed through `cleanse_string` before comparison
        include_deleted (bool): whether deleted pots are searched too

    Returns:
        Pot: the first pot whose name matches, or None if there is no match
    """

    def normalise(name: str) -> str:
        # Exact matching skips cleansing; lower-casing happens at comparison.
        return name if exact_match else cleanse_string(name)

    wanted_name = normalise(pot_name)

    return next(
        (
            candidate
            for candidate in self.list_pots(include_deleted=include_deleted)
            if normalise(candidate.name).lower() == wanted_name.lower()
        ),
        None,
    )

list_accounts(*, include_closed=False, account_type=None)

Get a list of the user's accounts.

Parameters:

Name Type Description Default
include_closed bool

whether to include closed accounts in the response

False
account_type str

the type of account(s) to find; submitted as param in request

None

Returns:

Name Type Description
list list[Account]

Account instances, containing all related info

Source code in wg_utilities/clients/monzo.py
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
def list_accounts(
    self,
    *,
    include_closed: bool = False,
    account_type: str | None = None,
) -> list[Account]:
    """Fetch the user's accounts from the `/accounts` endpoint.

    Args:
        include_closed (bool): whether closed accounts are kept in the result
        account_type (str): optional account type filter, sent as the
            `account_type` query parameter

    Returns:
        list: Account instances, one per account kept
    """
    query = {"account_type": account_type} if account_type else None
    payload = self.get_json_response("/accounts", params=query)

    accounts = []
    for raw_account in payload.get("accounts", []):
        # An entry without a "closed" flag is treated as closed.
        if raw_account.get("closed", True) and not include_closed:
            continue

        accounts.append(Account.from_json_response(raw_account, self))

    return accounts

list_pots(*, include_deleted=False)

Get a list of the user's pots.

Parameters:

Name Type Description Default
include_deleted bool

whether to include deleted pots in the response

False

Returns:

Name Type Description
list list[Pot]

Pot instances, containing all related info

Source code in wg_utilities/clients/monzo.py
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
def list_pots(self, *, include_deleted: bool = False) -> list[Pot]:
    """Fetch the pots attached to the user's current account.

    Args:
        include_deleted (bool): whether deleted pots are kept in the result

    Returns:
        list: Pot instances, one per pot kept
    """
    payload = self.get_json_response(
        "/pots",
        params={"current_account_id": self.current_account.id},
    )

    pots = []
    for raw_pot in payload.get("pots", []):
        # An entry without a "deleted" flag is treated as deleted.
        if raw_pot.get("deleted", True) and not include_deleted:
            continue

        pots.append(Pot(**raw_pot))

    return pots

OAuthClient

Bases: JsonApiClient[GetJsonResponse]

Custom client for interacting with OAuth APIs.

Includes all necessary/basic authentication functionality

Source code in wg_utilities/clients/oauth_client.py
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
class OAuthClient(JsonApiClient[GetJsonResponse]):
    """Custom client for interacting with OAuth APIs.

    Includes all necessary/basic authentication functionality: loading and caching
    credentials on disk, running the first time (authorization code) login, and
    refreshing the access token when it is close to expiry.
    """

    # Declared but not assigned here; `__init__` falls back to these when no
    # per-instance value is passed, so concrete clients need to define them.
    ACCESS_TOKEN_ENDPOINT: str
    AUTH_LINK_BASE: str

    # Safety margin in seconds: a token whose expiry is less than this far away is
    # treated as already expired (see `access_token_has_expired`).
    ACCESS_TOKEN_EXPIRY_THRESHOLD = 150

    # Read from the environment once, when the class body is executed.
    DEFAULT_CACHE_DIR = getenv("WG_UTILITIES_CREDS_CACHE_DIR")

    DEFAULT_SCOPES: ClassVar[list[str]] = []

    # Headless mode is on only when the environment variable is exactly "1".
    HEADLESS_MODE = getenv("WG_UTILITIES_HEADLESS_MODE", "0") == "1"

    _credentials: OAuthCredentials
    _temp_auth_server: TempAuthServer

    def __init__(  # noqa: PLR0913
        self,
        *,
        client_id: str | None = None,
        client_secret: str | None = None,
        log_requests: bool = False,
        creds_cache_path: Path | None = None,
        creds_cache_dir: Path | None = None,
        scopes: list[str] | None = None,
        oauth_login_redirect_host: str = "localhost",
        oauth_redirect_uri_override: str | None = None,
        headless_auth_link_callback: Callable[[str], None] | None = None,
        use_existing_credentials_only: bool = False,
        access_token_endpoint: str | None = None,
        auth_link_base: str | None = None,
        base_url: str | None = None,
        validate_request_success: bool = True,
    ):
        """Initialise the client.

        Args:
            client_id (str, optional): the client ID for the API. Defaults to None.
            client_secret (str, optional): the client secret for the API. Defaults to
                None.
            log_requests (bool, optional): whether to log requests. Defaults to False.
            creds_cache_path (Path, optional): the path to the credentials cache file.
                Defaults to None. Overrides `creds_cache_dir`.
            creds_cache_dir (Path, optional): the path to the credentials cache directory.
                Overrides environment variable `WG_UTILITIES_CREDS_CACHE_DIR`. Defaults to
                None.
            scopes (list[str], optional): the scopes to request when authenticating.
                Defaults to None.
            oauth_login_redirect_host (str, optional): the host to redirect to after
                authenticating. Defaults to "localhost".
            oauth_redirect_uri_override (str, optional): override the redirect URI
                specified in the OAuth flow. Defaults to None.
            headless_auth_link_callback (Callable[[str], None], optional): callback to
                send the auth link to when running in headless mode. Defaults to None.
            use_existing_credentials_only (bool, optional): whether to only use existing
                credentials, and not attempt to authenticate. Defaults to False.
            access_token_endpoint (str, optional): the endpoint to use to get an access
                token. Defaults to None.
            auth_link_base (str, optional): the base URL to use to generate the auth
                link. Defaults to None.
            base_url (str, optional): the base URL to use for API requests. Defaults to
                None.
            validate_request_success (bool, optional): whether to validate that the
                request was successful. Defaults to True.
        """
        super().__init__(
            log_requests=log_requests,
            base_url=base_url,
            validate_request_success=validate_request_success,
        )

        self._client_id = client_id
        self._client_secret = client_secret
        self.access_token_endpoint = access_token_endpoint or self.ACCESS_TOKEN_ENDPOINT
        self.auth_link_base = auth_link_base or self.AUTH_LINK_BASE
        self.oauth_login_redirect_host = oauth_login_redirect_host
        self.oauth_redirect_uri_override = oauth_redirect_uri_override
        self.headless_auth_link_callback = headless_auth_link_callback
        self.use_existing_credentials_only = use_existing_credentials_only

        # Precedence: explicit file path > explicit directory > environment default.
        if creds_cache_path:
            self._creds_cache_path: Path | None = creds_cache_path
            self._creds_cache_dir: Path | None = None
        elif creds_cache_dir:
            self._creds_cache_path = None
            self._creds_cache_dir = creds_cache_dir
        else:
            self._creds_cache_path = None
            if self.DEFAULT_CACHE_DIR:
                self._creds_cache_dir = Path(self.DEFAULT_CACHE_DIR)
            else:
                self._creds_cache_dir = None

        self.scopes = scopes or self.DEFAULT_SCOPES

        # Credentials are only loaded eagerly when an explicit cache file was given;
        # otherwise they are loaded lazily on first use.
        if self._creds_cache_path:
            self._load_local_credentials()

    def _load_local_credentials(self) -> bool:
        """Load credentials from the local cache.

        Returns:
            bool: True if the credentials were loaded successfully, False if the
                cache file does not exist
        """
        try:
            self._credentials = OAuthCredentials.model_validate_json(
                self.creds_cache_path.read_text(),
            )
        except FileNotFoundError:
            return False

        return True

    def delete_creds_file(self) -> None:
        """Delete the local creds file (no error if it doesn't exist)."""
        self.creds_cache_path.unlink(missing_ok=True)

    def refresh_access_token(self) -> None:
        """Refresh the access token and write the updated credentials to the cache.

        If no credentials are held in memory or in the local cache, the first time
        login is run instead and no refresh request is made.
        """

        if not hasattr(self, "_credentials") and not self._load_local_credentials():
            # If we don't have any credentials, we can't refresh the access token -
            # perform first time login and leave it at that
            self.run_first_time_login()
            return

        LOGGER.info("Refreshing access token")

        payload = {
            "grant_type": "refresh_token",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "refresh_token": self.credentials.refresh_token,
        }

        new_creds = self.post_json_response(
            self.access_token_endpoint,
            data=payload,
            header_overrides={},
        )

        self.credentials.update_access_token(
            new_token=new_creds["access_token"],
            expires_in=new_creds["expires_in"],
            # Not every response carries a new refresh token, hence `.get`; the
            # original note here attributes this field to Monzo
            refresh_token=new_creds.get("refresh_token"),
        )

        self.creds_cache_path.write_text(
            self.credentials.model_dump_json(exclude_none=True),
        )

    def run_first_time_login(self) -> None:
        """Run the first time login process.

        This is a blocking call which will not return until the user has
        authenticated with the OAuth provider.

        Raises:
            RuntimeError: if `use_existing_credentials_only` is set to True
            ValueError: if the state token returned by the OAuth provider does not
                match
        """
        # Imported here so this change needs no edit to the module's import block.
        import secrets

        if self.use_existing_credentials_only:
            raise RuntimeError(
                "No existing credentials found, and `use_existing_credentials_only` "
                "is set to True",
            )

        LOGGER.info("Performing first time login")

        # The state token guards the redirect against forgery, so it is drawn from
        # a cryptographically secure source rather than the `random` module.
        state_token = "".join(secrets.choice(ascii_letters) for _ in range(32))

        self.temp_auth_server.start_server()

        if self.oauth_redirect_uri_override:
            redirect_uri = self.oauth_redirect_uri_override
        else:
            redirect_uri = f"http://{self.oauth_login_redirect_host}:{self.temp_auth_server.port}/get_auth_code"

        auth_link_params = {
            "client_id": self._client_id,
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "state": state_token,
            "access_type": "offline",
            "prompt": "consent",
        }

        if self.scopes:
            auth_link_params["scope"] = " ".join(self.scopes)

        auth_link = self.auth_link_base + "?" + urlencode(auth_link_params)

        if self.HEADLESS_MODE:
            if self.headless_auth_link_callback is None:
                LOGGER.warning(
                    "Headless mode is enabled, but no headless auth link callback "
                    "has been set. The auth link will not be opened.",
                )
                LOGGER.debug("Auth link: %s", auth_link)
            else:
                LOGGER.info("Sending auth link to callback")
                self.headless_auth_link_callback(auth_link)
        else:
            open_browser(auth_link)

        # Blocks until the provider redirects the user back to the temporary server.
        request_args = self.temp_auth_server.wait_for_request(
            "/get_auth_code",
            kill_on_request=True,
        )

        if state_token != request_args.get("state"):
            raise ValueError(
                "State token received in request doesn't match expected value: "
                f"`{request_args.get('state')}` != `{state_token}`",
            )

        # These two clients get the token request as form data with an explicit
        # form-encoded content type; every other client gets a JSON body.
        uses_form_encoding = self.__class__.__name__ in (
            "MonzoClient",
            "SpotifyClient",
        )
        payload_key = "data" if uses_form_encoding else "json"

        res = self._post(
            self.access_token_endpoint,
            **{  # type: ignore[arg-type]
                payload_key: {
                    "code": request_args["code"],
                    "grant_type": "authorization_code",
                    "client_id": self._client_id,
                    "client_secret": self._client_secret,
                    "redirect_uri": redirect_uri,
                },
            },
            # Stops recursive call to `self.request_headers`
            header_overrides=(
                {"Content-Type": "application/x-www-form-urlencoded"}
                if uses_form_encoding
                else {}
            ),
        )

        credentials = res.json()

        # Store the client details alongside the tokens so they can be read back
        # from the credentials later (see `client_id` / `client_secret`).
        if self._client_id:
            credentials["client_id"] = self._client_id

        if self._client_secret:
            credentials["client_secret"] = self._client_secret

        # Assigning via the property setter also writes the cache file.
        self.credentials = OAuthCredentials.parse_first_time_login(credentials)

    @property
    def _creds_rel_file_path(self) -> Path | None:
        """Get the credentials cache filepath relative to the cache directory.

        Overridable in subclasses.

        Returns:
            Path | None: `<class name>/<client ID>.json`, or None if no client ID is
                available yet
        """

        try:
            client_id = self._client_id or self._credentials.client_id
        except AttributeError:
            # No client ID was passed in and no credentials have been loaded yet
            return None

        return Path(type(self).__name__, f"{client_id}.json")

    @property
    def access_token(self) -> str | None:
        """Access token, refreshed first if it has expired (or is about to).

        Returns:
            str: the current access token
        """
        if self.access_token_has_expired:
            self.refresh_access_token()

        return self.credentials.access_token

    @property
    def access_token_has_expired(self) -> bool:
        """Check whether the access token has expired, or is about to.

        Returns:
            bool: True if no credentials could be loaded, or if the credentials'
                expiry is less than `ACCESS_TOKEN_EXPIRY_THRESHOLD` seconds away
        """
        if not hasattr(self, "_credentials") and not self._load_local_credentials():
            return True

        return (
            self.credentials.expiry_epoch
            < int(time()) + self.ACCESS_TOKEN_EXPIRY_THRESHOLD
        )

    @property
    def client_id(self) -> str:
        """Client ID.

        Returns:
            str: the client ID passed at instantiation if set, otherwise the one
                stored in the credentials
        """

        return self._client_id or self.credentials.client_id

    @property
    def client_secret(self) -> str | None:
        """Client secret.

        Returns:
            str: the client secret passed at instantiation if set, otherwise the one
                stored in the credentials
        """

        return self._client_secret or self.credentials.client_secret

    @property
    def credentials(self) -> OAuthCredentials:
        """Get creds as necessary, running the first time login if none exist.

        Returns:
            OAuthCredentials: the credentials for this client

        Raises:
            RuntimeError: if no credentials exist and `use_existing_credentials_only`
                is set to True
            ValueError: if the state token returned from the request doesn't match the
                expected value
        """
        if not hasattr(self, "_credentials") and not self._load_local_credentials():
            self.run_first_time_login()

        return self._credentials

    @credentials.setter
    def credentials(self, value: OAuthCredentials) -> None:
        """Set the client's credentials, and write to the local cache file."""

        self._credentials = value

        self.creds_cache_path.write_text(
            dumps(self._credentials.model_dump(exclude_none=True)),
        )

    @property
    def creds_cache_path(self) -> Path:
        """Path to the credentials cache file.

        Returns:
            Path: the path to the credentials cache file

        Raises:
            ValueError: if the path to the credentials cache file is not set, and can't
                be generated due to a lack of client ID
        """
        if self._creds_cache_path:
            return self._creds_cache_path

        if not self._creds_rel_file_path:
            raise ValueError(
                "Unable to get client ID to generate path for credentials cache file.",
            )

        # NOTE(review): `force_mkdir` presumably creates the parent directories and
        # returns the path - confirm against its definition.
        return force_mkdir(
            (self._creds_cache_dir or user_data_dir() / "oauth_credentials")
            / self._creds_rel_file_path,
            path_is_file=True,
        )

    @property
    def request_headers(self) -> dict[str, str]:
        """Header to be used in requests to the API.

        Returns:
            dict: auth headers for HTTP requests
        """
        return {
            "Authorization": f"Bearer {self.access_token}",
            "Content-Type": "application/json",
        }

    @property
    def refresh_token(self) -> str:
        """Refresh token.

        Returns:
            str: the API refresh token
        """
        return self.credentials.refresh_token

    @property
    def temp_auth_server(self) -> TempAuthServer:
        """Create a temporary HTTP server for the auth flow.

        The server is created on first access and reused afterwards.

        Returns:
            TempAuthServer: the temporary server
        """
        if not hasattr(self, "_temp_auth_server"):
            self._temp_auth_server = TempAuthServer(__name__, auto_run=False)

        return self._temp_auth_server

access_token: str | None property

Access token.

Returns:

Name Type Description
str str | None

the access token for the API, refreshed first if it has expired

access_token_has_expired: bool property

Check whether the access token has expired, or is within the expiry threshold of doing so.

Returns:

Name Type Description
bool bool

has the access token expired?

client_id: str property

Client ID for the API.

Returns:

Name Type Description
str str

the current client ID

client_secret: str | None property

Client secret.

Returns:

Name Type Description
str str | None

the current client secret

credentials: OAuthCredentials property writable

Get creds as necessary (including first time setup) and authenticates them.

Returns:

Name Type Description
OAuthCredentials OAuthCredentials

the credentials for this client

Raises:

Type Description
ValueError

if the state token returned from the request doesn't match the expected value

creds_cache_path: Path property

Path to the credentials cache file.

Returns:

Name Type Description
Path Path

the path to the credentials cache file

Raises:

Type Description
ValueError

if the path to the credentials cache file is not set, and can't be generated due to a lack of client ID

refresh_token: str property

Refresh token.

Returns:

Name Type Description
str str

the API refresh token

request_headers: dict[str, str] property

Header to be used in requests to the API.

Returns:

Name Type Description
dict dict[str, str]

auth headers for HTTP requests

temp_auth_server: TempAuthServer property

Create a temporary HTTP server for the auth flow.

Returns:

Name Type Description
TempAuthServer TempAuthServer

the temporary server

__init__(*, client_id=None, client_secret=None, log_requests=False, creds_cache_path=None, creds_cache_dir=None, scopes=None, oauth_login_redirect_host='localhost', oauth_redirect_uri_override=None, headless_auth_link_callback=None, use_existing_credentials_only=False, access_token_endpoint=None, auth_link_base=None, base_url=None, validate_request_success=True)

Initialise the client.

Parameters:

Name Type Description Default
client_id str

the client ID for the API. Defaults to None.

None
client_secret str

the client secret for the API. Defaults to None.

None
log_requests bool

whether to log requests. Defaults to False.

False
creds_cache_path Path

the path to the credentials cache file. Defaults to None. Overrides creds_cache_dir.

None
creds_cache_dir Path

the path to the credentials cache directory. Overrides environment variable WG_UTILITIES_CREDS_CACHE_DIR. Defaults to None.

None
scopes list[str]

the scopes to request when authenticating. Defaults to None.

None
oauth_login_redirect_host str

the host to redirect to after authenticating. Defaults to "localhost".

'localhost'
oauth_redirect_uri_override str

override the redirect URI specified in the OAuth flow. Defaults to None.

None
headless_auth_link_callback Callable[[str], None]

callback to send the auth link to when running in headless mode. Defaults to None.

None
use_existing_credentials_only bool

whether to only use existing credentials, and not attempt to authenticate. Defaults to False.

False
access_token_endpoint str

the endpoint to use to get an access token. Defaults to None.

None
auth_link_base str

the base URL to use to generate the auth link. Defaults to None.

None
base_url str

the base URL to use for API requests. Defaults to None.

None
validate_request_success bool

whether to validate that the request was successful. Defaults to True.

True
Source code in wg_utilities/clients/oauth_client.py
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
def __init__(  # noqa: PLR0913
    self,
    *,
    client_id: str | None = None,
    client_secret: str | None = None,
    log_requests: bool = False,
    creds_cache_path: Path | None = None,
    creds_cache_dir: Path | None = None,
    scopes: list[str] | None = None,
    oauth_login_redirect_host: str = "localhost",
    oauth_redirect_uri_override: str | None = None,
    headless_auth_link_callback: Callable[[str], None] | None = None,
    use_existing_credentials_only: bool = False,
    access_token_endpoint: str | None = None,
    auth_link_base: str | None = None,
    base_url: str | None = None,
    validate_request_success: bool = True,
):
    """Set up the OAuth client.

    When `creds_cache_path` is given, cached credentials are loaded from it
    straight away (if the file exists).

    Args:
        client_id (str, optional): client ID for the API. Defaults to None.
        client_secret (str, optional): client secret for the API. Defaults to None.
        log_requests (bool, optional): whether requests should be logged. Defaults
            to False.
        creds_cache_path (Path, optional): credentials cache file. Takes precedence
            over `creds_cache_dir`. Defaults to None.
        creds_cache_dir (Path, optional): credentials cache directory. Takes
            precedence over the `WG_UTILITIES_CREDS_CACHE_DIR` environment variable.
            Defaults to None.
        scopes (list[str], optional): scopes to request when authenticating; the
            class's `DEFAULT_SCOPES` are used when omitted. Defaults to None.
        oauth_login_redirect_host (str, optional): host to redirect to after
            authenticating. Defaults to "localhost".
        oauth_redirect_uri_override (str, optional): redirect URI to use instead of
            the generated one. Defaults to None.
        headless_auth_link_callback (Callable[[str], None], optional): receives the
            auth link when running in headless mode. Defaults to None.
        use_existing_credentials_only (bool, optional): never attempt a first time
            login; rely on existing credentials only. Defaults to False.
        access_token_endpoint (str, optional): token endpoint; the class's
            `ACCESS_TOKEN_ENDPOINT` is used when omitted. Defaults to None.
        auth_link_base (str, optional): base URL of the auth link; the class's
            `AUTH_LINK_BASE` is used when omitted. Defaults to None.
        base_url (str, optional): base URL for API requests. Defaults to None.
        validate_request_success (bool, optional): whether to validate that requests
            were successful. Defaults to True.
    """
    super().__init__(
        base_url=base_url,
        log_requests=log_requests,
        validate_request_success=validate_request_success,
    )

    # Client identity
    self._client_id = client_id
    self._client_secret = client_secret

    # Endpoints: per-instance value first, class-level constant otherwise
    self.auth_link_base = auth_link_base or self.AUTH_LINK_BASE
    self.access_token_endpoint = access_token_endpoint or self.ACCESS_TOKEN_ENDPOINT

    # Login flow options
    self.use_existing_credentials_only = use_existing_credentials_only
    self.headless_auth_link_callback = headless_auth_link_callback
    self.oauth_redirect_uri_override = oauth_redirect_uri_override
    self.oauth_login_redirect_host = oauth_login_redirect_host
    self.scopes = scopes or self.DEFAULT_SCOPES

    # A cache directory only matters when no explicit cache file was supplied;
    # an explicit directory then wins over the environment-provided default.
    cache_dir: Path | None = None
    if not creds_cache_path:
        if creds_cache_dir:
            cache_dir = creds_cache_dir
        elif self.DEFAULT_CACHE_DIR:
            cache_dir = Path(self.DEFAULT_CACHE_DIR)

    self._creds_cache_path: Path | None = creds_cache_path or None
    self._creds_cache_dir: Path | None = cache_dir

    if self._creds_cache_path:
        self._load_local_credentials()

delete_creds_file()

Delete the local creds file.

Source code in wg_utilities/clients/oauth_client.py
357
358
359
def delete_creds_file(self) -> None:
    """Remove the cached credentials file; a missing file is not an error."""
    cache_file = self.creds_cache_path
    cache_file.unlink(missing_ok=True)

refresh_access_token()

Refresh access token.

Source code in wg_utilities/clients/oauth_client.py
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
def refresh_access_token(self) -> None:
    """Exchange the refresh token for a new access token and cache the result.

    When no credentials are available in memory or in the local cache, the first
    time login is run instead and no refresh request is sent.
    """
    credentials_available = (
        hasattr(self, "_credentials") or self._load_local_credentials()
    )

    if not credentials_available:
        # Nothing to refresh with - the first time login obtains fresh tokens
        self.run_first_time_login()
        return

    LOGGER.info("Refreshing access token")

    refresh_request = dict(
        grant_type="refresh_token",
        client_id=self.client_id,
        client_secret=self.client_secret,
        refresh_token=self.credentials.refresh_token,
    )

    token_response = self.post_json_response(
        self.access_token_endpoint,
        data=refresh_request,
        header_overrides={},
    )

    self.credentials.update_access_token(
        new_token=token_response["access_token"],
        expires_in=token_response["expires_in"],
        # May be absent from the response, hence `.get` (original note: Monzo)
        refresh_token=token_response.get("refresh_token"),
    )

    # Persist the updated credentials
    cache_file = self.creds_cache_path
    cache_file.write_text(self.credentials.model_dump_json(exclude_none=True))

run_first_time_login()

Run the first time login process.

This is a blocking call which will not return until the user has authenticated with the OAuth provider.

Raises:

Type Description
RuntimeError

if use_existing_credentials_only is set to True

ValueError

if the state token returned by the OAuth provider does not match

Source code in wg_utilities/clients/oauth_client.py
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
def run_first_time_login(self) -> None:
    """Run the first time login process.

    This is a blocking call which will not return until the user has
    authenticated with the OAuth provider.

    Raises:
        RuntimeError: if `use_existing_credentials_only` is set to True
        ValueError: if the state token returned by the OAuth provider does not
            match
    """
    # Imported here so this change needs no edit to the module's import block.
    import secrets

    if self.use_existing_credentials_only:
        raise RuntimeError(
            "No existing credentials found, and `use_existing_credentials_only` "
            "is set to True",
        )

    LOGGER.info("Performing first time login")

    # The state token guards the redirect against forgery, so it is drawn from a
    # cryptographically secure source rather than the `random` module.
    state_token = "".join(secrets.choice(ascii_letters) for _ in range(32))

    self.temp_auth_server.start_server()

    if self.oauth_redirect_uri_override:
        redirect_uri = self.oauth_redirect_uri_override
    else:
        redirect_uri = f"http://{self.oauth_login_redirect_host}:{self.temp_auth_server.port}/get_auth_code"

    auth_link_params = {
        "client_id": self._client_id,
        "redirect_uri": redirect_uri,
        "response_type": "code",
        "state": state_token,
        "access_type": "offline",
        "prompt": "consent",
    }

    if self.scopes:
        auth_link_params["scope"] = " ".join(self.scopes)

    auth_link = self.auth_link_base + "?" + urlencode(auth_link_params)

    if self.HEADLESS_MODE:
        if self.headless_auth_link_callback is None:
            LOGGER.warning(
                "Headless mode is enabled, but no headless auth link callback "
                "has been set. The auth link will not be opened.",
            )
            LOGGER.debug("Auth link: %s", auth_link)
        else:
            LOGGER.info("Sending auth link to callback")
            self.headless_auth_link_callback(auth_link)
    else:
        open_browser(auth_link)

    # Blocks until the provider redirects the user back to the temporary server.
    request_args = self.temp_auth_server.wait_for_request(
        "/get_auth_code",
        kill_on_request=True,
    )

    if state_token != request_args.get("state"):
        raise ValueError(
            "State token received in request doesn't match expected value: "
            f"`{request_args.get('state')}` != `{state_token}`",
        )

    # These two clients get the token request as form data with an explicit
    # form-encoded content type; every other client gets a JSON body.
    uses_form_encoding = self.__class__.__name__ in (
        "MonzoClient",
        "SpotifyClient",
    )
    payload_key = "data" if uses_form_encoding else "json"

    res = self._post(
        self.access_token_endpoint,
        **{  # type: ignore[arg-type]
            payload_key: {
                "code": request_args["code"],
                "grant_type": "authorization_code",
                "client_id": self._client_id,
                "client_secret": self._client_secret,
                "redirect_uri": redirect_uri,
            },
        },
        # Stops recursive call to `self.request_headers`
        header_overrides=(
            {"Content-Type": "application/x-www-form-urlencoded"}
            if uses_form_encoding
            else {}
        ),
    )

    credentials = res.json()

    # Store the client details alongside the tokens returned by the provider.
    if self._client_id:
        credentials["client_id"] = self._client_id

    if self._client_secret:
        credentials["client_secret"] = self._client_secret

    self.credentials = OAuthCredentials.parse_first_time_login(credentials)

SpotifyClient

Bases: OAuthClient[SpotifyEntityJson]

Custom client for interacting with Spotify's Web API.

For authentication purposes either an already-instantiated OAuth manager or the relevant credentials must be provided

Parameters:

Name Type Description Default
client_id str

the application's client ID

None
client_secret str

the application's client secret

None
log_requests bool

flag for choosing if to log all requests made

False
creds_cache_path str

path at which to save cached credentials

None
Source code in wg_utilities/clients/spotify.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
class SpotifyClient(OAuthClient[SpotifyEntityJson]):
    """Custom client for interacting with Spotify's Web API.

    For authentication purposes either an already-instantiated OAuth manager or the
    relevant credentials must be provided.

    Args:
        client_id (str): the application's client ID
        client_secret (str): the application's client secret
        log_requests (bool): flag for choosing whether to log all requests made
        creds_cache_path (str): path at which to save cached credentials
    """

    AUTH_LINK_BASE = "https://accounts.spotify.com/authorize"
    ACCESS_TOKEN_ENDPOINT = "https://accounts.spotify.com/api/token"  # noqa: S105
    BASE_URL = "https://api.spotify.com/v1"

    DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

    DEFAULT_SCOPES: ClassVar[list[str]] = [
        "ugc-image-upload",
        "user-read-recently-played",
        "user-top-read",
        "user-read-playback-position",
        "user-read-playback-state",
        "user-modify-playback-state",
        "user-read-currently-playing",
        "app-remote-control",
        "streaming",
        "playlist-modify-public",
        "playlist-modify-private",
        "playlist-read-private",
        "playlist-read-collaborative",
        "user-follow-modify",
        "user-follow-read",
        "user-library-modify",
        "user-library-read",
        "user-read-email",
        "user-read-private",
    ]

    # Entity types accepted by `search`
    SEARCH_TYPES: tuple[Literal["album", "artist", "playlist", "track"], ...] = (
        "album",
        "artist",
        "playlist",
        "track",
        # "show",
        # "episode",
    )

    # Populated lazily by the `current_user` property
    _current_user: User

    def add_tracks_to_playlist(
        self,
        tracks: Iterable[Track],
        playlist: Playlist,
        *,
        log_responses: bool = False,
        force_add: bool = False,
        update_instance_tracklist: bool = True,
    ) -> list[Track]:
        """Add one or more tracks to a playlist.

        If `force_add` is False, a check is made against the Playlist's tracklist to
        ensure that the track isn't already in the playlist. This can be slow if it's
        a (new) big playlist. Local tracks are always skipped.

        Args:
            tracks (Iterable[Track]): the Track instances to be added to the given
                playlist
            playlist (Playlist): the playlist being updated
            log_responses (bool): log each individual response
            force_add (bool): flag for adding the track even if it's in the playlist
                already
            update_instance_tracklist (bool): appends the track to the Playlist's
                tracklist. Can be slow if it's a big playlist as it might have to get
                the tracklist first

        Returns:
            list[Track]: the tracks which were actually sent to the API
        """

        tracks_to_add = [
            track
            for track in tracks
            if not track.is_local and (force_add or track not in playlist)
        ]

        # The tracks are posted in batches of (at most) 100 URIs per request
        for chunk in chunk_list(tracks_to_add, 100):
            res = self._post(
                f"/playlists/{playlist.id}/tracks",
                json={"uris": [t.uri for t in chunk]},
            )

            if log_responses:
                LOGGER.info(dumps(res.json()))

        if update_instance_tracklist:
            playlist.tracks.extend(tracks_to_add)

        return tracks_to_add

    def create_playlist(
        self,
        *,
        name: str,
        description: str = "",
        public: bool = False,
        collaborative: bool = False,
    ) -> Playlist:
        """Create a new playlist under the current user's account.

        Args:
            name (str): the name of the new playlist
            description (str): the description of the new playlist
            public (bool): determines if the playlist is publicly accessible
            collaborative (bool): allows other people to add tracks when True

        Returns:
            Playlist: an instance of the Playlist class containing the new playlist's
                metadata
        """
        res = self._post(
            f"/users/{self.current_user.id}/playlists",
            json={
                "name": name,
                "description": description,
                "public": public,
                "collaborative": collaborative,
            },
        )

        return Playlist.from_json_response(res.json(), spotify_client=self)

    def get_album_by_id(self, id_: str) -> Album:
        """Get an album from Spotify based on the ID.

        Args:
            id_ (str): the Spotify ID which is used to find the album

        Returns:
            Album: an instantiated Album, from the API's response
        """

        return Album.from_json_response(
            self.get_json_response(f"/albums/{id_}"),
            spotify_client=self,
        )

    def get_artist_by_id(self, id_: str) -> Artist:
        """Get an artist from Spotify based on the ID.

        Args:
            id_ (str): the Spotify ID which is used to find the artist

        Returns:
            Artist: an instantiated Artist, from the API's response
        """

        return Artist.from_json_response(
            self.get_json_response(f"/artists/{id_}"),
            spotify_client=self,
        )

    def get_items(
        self,
        url: str,
        *,
        params: None | dict[str, str | int | float | bool | dict[str, Any]] = None,
        hard_limit: int = 1000000,
        limit_func: (
            Callable[
                [dict[str, Any] | SpotifyEntityJson],
                bool,
            ]
            | None
        ) = None,
        top_level_key: (
            Literal[
                "albums",
                "artists",
                "audiobooks",
                "episodes",
                "playlists",
                "shows",
                "tracks",
            ]
            | None
        ) = None,
        list_key: Literal["items", "devices"] = "items",
    ) -> list[SpotifyEntityJson]:
        """Retrieve a list of items from a given URL, including pagination.

        Args:
            url (str): the API endpoint which we're listing
            params (dict): any params to pass with the API request. The dict is
                copied, so the caller's instance is never modified
            hard_limit (int): a hard limit to apply to the number of items returned (as
                opposed to the "soft" limit of 50 imposed by the API)
            limit_func (Callable): a function which is used to evaluate each item in
                turn: if it returns False, the item is added to the output list; if it
                returns True then the iteration stops and the list is returned as-is
            top_level_key (str): an optional key to use when the items in the response
                are nested 1 level deeper than normal
            list_key (Literal["devices", "items"]): the key in the response which
                contains the list of items

        Returns:
            list: a list of dicts representing the Spotify items
        """

        # Work on a copy so that adding `limit` doesn't leak into the caller's dict
        params = dict(params) if params else {}
        if "limit=" not in url:
            params["limit"] = min(50, hard_limit)

        items: list[SpotifyEntityJson] = []

        if params:
            url += ("?" if "?" not in url else "&") + urlencode(params)

        # A dummy "page zero", whose `next` URL is the first real page to request
        page: AnyPaginatedResponse = {
            "href": "",
            "items": [],
            "limit": 0,
            "next": url,
            "offset": 0,
            "total": 0,
        }

        while (next_url := page.get("next")) and len(items) < hard_limit:
            # Ensure we don't bother getting more items than we need
            limit = min(50, hard_limit - len(items))
            # Replace the whole numeric value, however many digits it has
            next_url = sub(r"(?<=limit=)\d+", str(limit), next_url)

            res: SearchResponse | AnyPaginatedResponse = self.get_json_response(next_url)  # type: ignore[assignment]
            page = (
                cast("SearchResponse", res)[top_level_key]
                if top_level_key
                else cast("AnyPaginatedResponse", res)
            )

            page_items: (
                list[AlbumSummaryJson]
                | list[DeviceJson]
                | list[ArtistSummaryJson]
                | list[PlaylistSummaryJson]
                | list[TrackFullJson]
            ) = page.get(list_key, [])
            if limit_func is None:
                items.extend(page_items)
            else:
                # Initialise `limit_reached` to False, and then it will be set to
                # True on the first matching item. This will then cause the loop to
                # skip subsequent items (without calling `limit_func` again, thanks
                # to the short-circuiting `or`).
                limit_reached = False
                items.extend(
                    [
                        item
                        for item in page_items
                        if (not (limit_reached := (limit_reached or limit_func(item))))
                    ],
                )
                if limit_reached:
                    return items

        return items

    def get_playlist_by_id(self, id_: str) -> Playlist:
        """Get a playlist from Spotify based on the ID.

        If the current user and their playlists have already been loaded on this
        client, those are searched first; otherwise (or if there is no match) the
        playlist is requested from the API.

        Args:
            id_ (str): the Spotify ID which is used to find the playlist

        Returns:
            Playlist: an instantiated Playlist, from the API's response
        """

        # `hasattr` on the private attributes avoids triggering a lookup of the
        # user/playlists purely for the sake of this check
        if hasattr(self, "_current_user") and hasattr(self.current_user, "_playlists"):
            for playlist in self.current_user.playlists:
                if playlist.id == id_:
                    return playlist

        return Playlist.from_json_response(
            self.get_json_response(f"/playlists/{id_}"),
            spotify_client=self,
        )

    def get_track_by_id(self, id_: str) -> Track:
        """Get a track from Spotify based on the ID.

        Args:
            id_ (str): the Spotify ID which is used to find the track

        Returns:
            Track: an instantiated Track, from the API's response
        """

        return Track.from_json_response(
            self.get_json_response(f"/tracks/{id_}"),
            spotify_client=self,
        )

    @overload
    def search(
        self,
        search_term: str,
        *,
        entity_types: Sequence[Literal["album", "artist", "playlist", "track"]] = (),
        get_best_match_only: Literal[True],
    ) -> Artist | Playlist | Track | Album | None: ...

    @overload
    def search(
        self,
        search_term: str,
        *,
        entity_types: Sequence[Literal["album", "artist", "playlist", "track"]] = (),
        get_best_match_only: Literal[False] = False,
    ) -> ParsedSearchResponse: ...

    def search(
        self,
        search_term: str,
        *,
        entity_types: Sequence[Literal["album", "artist", "playlist", "track"]] = (),
        get_best_match_only: bool = False,
    ) -> Artist | Playlist | Track | Album | None | ParsedSearchResponse:
        """Search Spotify for a given search term.

        Args:
            search_term (str): the term to use as the base of the search
            entity_types (Sequence[str]): the types of entity to search for. Each
                must be one of SpotifyClient.SEARCH_TYPES; if empty, all of
                SpotifyClient.SEARCH_TYPES are searched
            get_best_match_only (bool): return a single entity from the top of the
                list, rather than all matches

        Returns:
            Artist | Playlist | Track | Album | None: a single entity if the best
                match flag is set, or None if there were no results
            dict: a dict of entities, by type

        Raises:
            ValueError: if the best match flag is true and the number of entity
                types requested is not exactly one
            ValueError: if one of entity_types is an invalid value
        """

        entity_types = entity_types or self.SEARCH_TYPES

        if get_best_match_only is True and len(entity_types) != 1:
            raise ValueError(
                "Exactly one entity type must be requested if `get_best_match_only`"
                " is True",
            )

        entity_type: Literal["artist", "playlist", "track", "album"]
        for entity_type in entity_types:
            if entity_type not in self.SEARCH_TYPES:
                raise ValueError(
                    f"Unexpected value for entity type: '{entity_type}'. Must be"
                    f" one of {self.SEARCH_TYPES!r}",
                )

        res: SearchResponse = self.get_json_response(  # type: ignore[assignment]
            "/search",
            params={
                "query": search_term,
                "type": ",".join(entity_types),
                "limit": 1 if get_best_match_only else 50,
            },
        )

        entity_instances: ParsedSearchResponse = {}

        res_entity_type: Literal["albums", "artists", "playlists", "tracks"]
        entities_json: (
            PaginatedResponseAlbums
            | PaginatedResponseArtists
            | PaginatedResponsePlaylists
            | PaginatedResponseTracks
        )
        for res_entity_type, entities_json in res.items():  # type: ignore[assignment]
            instance_class: type[Album] | type[Artist] | type[Playlist] | type[Track] = {  # type: ignore[assignment]
                "albums": Album,
                "artists": Artist,
                "playlists": Playlist,
                "tracks": Track,
            }[res_entity_type]

            if get_best_match_only:
                try:
                    # Take the entity off the top of the list
                    return instance_class.from_json_response(
                        entities_json["items"][0],
                        spotify_client=self,
                    )
                except LookupError:
                    return None

            entity_instances.setdefault(res_entity_type, []).extend(
                [
                    instance_class.from_json_response(entity_json, spotify_client=self)  # type: ignore[misc]
                    for entity_json in entities_json.get("items", [])
                ],
            )

            # Each entity type has its own type-specific next URL
            if (next_url := entities_json.get("next")) is not None:
                entity_instances[res_entity_type].extend(
                    [
                        instance_class.from_json_response(  # type: ignore[misc]
                            item,
                            spotify_client=self,
                        )
                        for item in self.get_items(
                            next_url,
                            top_level_key=res_entity_type,
                        )
                    ],
                )

        if get_best_match_only:
            # The response held no entity types at all, so there is no best match
            # (rather than returning the empty dict, which breaks the overload)
            return None

        return entity_instances

    @property
    def current_user(self) -> User:
        """Get the current user's info.

        The user is requested from the `/me` endpoint on first access and then
        cached on the instance.

        Returns:
            User: an instance of the current Spotify user
        """
        if not hasattr(self, "_current_user"):
            self._current_user = User.from_json_response(
                self.get_json_response("/me"),
                spotify_client=self,
            )

        return self._current_user

current_user: User property

Get the current user's info.

Returns:

Name Type Description
User User

an instance of the current Spotify user

add_tracks_to_playlist(tracks, playlist, *, log_responses=False, force_add=False, update_instance_tracklist=True)

Add one or more tracks to a playlist.

If force_add is False, a check is made against the Playlist's tracklist to ensure that the track isn't already in the playlist. This can be slow if it's a (new) big playlist.

Parameters:

Name Type Description Default
tracks list

a list of Track instances to be added to the given playlist

required
playlist Playlist

the playlist being updated

required
log_responses bool

log each individual response

False
force_add bool

flag for adding the track even if it's in the playlist already

False
update_instance_tracklist bool

appends the track to the Playlist's tracklist. Can be slow if it's a big playlist as it might have to get the tracklist first

True
Source code in wg_utilities/clients/spotify.py
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
def add_tracks_to_playlist(
    self,
    tracks: Iterable[Track],
    playlist: Playlist,
    *,
    log_responses: bool = False,
    force_add: bool = False,
    update_instance_tracklist: bool = True,
) -> list[Track]:
    """Add one or more tracks to a playlist.

    Local tracks are never added. Unless `force_add` is set, each track is also
    checked against the playlist so that duplicates are skipped; that check can
    be slow for a (new) big playlist.

    Args:
        tracks (list): a list of Track instances to be added to the given playlist
        playlist (Playlist): the playlist being updated
        log_responses (bool): log each individual response
        force_add (bool): flag for adding the track even if it's in the playlist
            already
        update_instance_tracklist (bool): appends the track to the Playlist's
            tracklist. Can be slow if it's a big playlist as it might have to get
            the tracklist first

    Returns:
        list[Track]: the tracks which were selected for adding
    """

    selected: list[Track] = []
    for candidate in tracks:
        if candidate.is_local:
            continue
        if force_add or candidate not in playlist:
            selected.append(candidate)

    # One POST per batch of up to 100 track URIs
    for batch in chunk_list(selected, 100):
        payload = {"uris": [entry.uri for entry in batch]}
        response = self._post(f"/playlists/{playlist.id}/tracks", json=payload)

        if log_responses:
            LOGGER.info(dumps(response.json()))

    if update_instance_tracklist:
        playlist.tracks.extend(selected)

    return selected

create_playlist(*, name, description='', public=False, collaborative=False)

Create a new playlist under the current user's account.

Parameters:

Name Type Description Default
name str

the name of the new playlist

required
description str

the description of the new playlist

''
public bool

determines if the playlist is publicly accessible

False
collaborative bool

allows other people to add tracks when True

False

Returns:

Name Type Description
Playlist Playlist

an instance of the Playlist class containing the new playlist's metadata

Source code in wg_utilities/clients/spotify.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
def create_playlist(
    self,
    *,
    name: str,
    description: str = "",
    public: bool = False,
    collaborative: bool = False,
) -> Playlist:
    """Create a new playlist under the current user's account.

    Args:
        name (str): the name of the new playlist
        description (str): the description of the new playlist
        public (bool): determines if the playlist is publicly accessible
        collaborative (bool): allows other people to add tracks when True

    Returns:
        Playlist: a Playlist instance built from the API's response to the
            creation request
    """
    endpoint = f"/users/{self.current_user.id}/playlists"
    payload = {
        "name": name,
        "description": description,
        "public": public,
        "collaborative": collaborative,
    }

    response = self._post(endpoint, json=payload)

    return Playlist.from_json_response(response.json(), spotify_client=self)

get_album_by_id(id_)

Get an album from Spotify based on the ID.

Parameters:

Name Type Description Default
id_ str

the Spotify ID which is used to find the album

required

Returns:

Name Type Description
Album Album

an instantiated Album, from the API's response

Source code in wg_utilities/clients/spotify.py
247
248
249
250
251
252
253
254
255
256
257
258
259
260
def get_album_by_id(self, id_: str) -> Album:
    """Fetch a single album from Spotify, by its ID.

    Args:
        id_ (str): the Spotify ID which is used to find the album

    Returns:
        Album: an instantiated Album, from the API's response
    """

    album_json = self.get_json_response(f"/albums/{id_}")

    return Album.from_json_response(album_json, spotify_client=self)

get_artist_by_id(id_)

Get an artist from Spotify based on the ID.

Parameters:

Name Type Description Default
id_ str

the Spotify ID which is used to find the artist

required

Returns:

Name Type Description
Artist Artist

an instantiated Artist, from the API's response

Source code in wg_utilities/clients/spotify.py
262
263
264
265
266
267
268
269
270
271
272
273
274
275
def get_artist_by_id(self, id_: str) -> Artist:
    """Fetch a single artist from Spotify, by its ID.

    Args:
        id_ (str): the Spotify ID which is used to find the artist

    Returns:
        Artist: an instantiated Artist, from the API's response
    """

    artist_json = self.get_json_response(f"/artists/{id_}")

    return Artist.from_json_response(artist_json, spotify_client=self)

get_items(url, *, params=None, hard_limit=1000000, limit_func=None, top_level_key=None, list_key='items')

Retrieve a list of items from a given URL, including pagination.

Parameters:

Name Type Description Default
url str

the API endpoint which we're listing

required
params dict

any params to pass with the API request

None
hard_limit int

a hard limit to apply to the number of items returned (as opposed to the "soft" limit of 50 imposed by the API)

1000000
limit_func Callable

a function which is used to evaluate each item in turn: if it returns False, the item is added to the output list; if it returns True then the iteration stops and the list is returned as-is

None
top_level_key str

an optional key to use when the items in the response are nested 1 level deeper than normal

None
list_key Literal['devices', 'items']

the key in the response which contains the list of items

'items'

Returns:

Name Type Description
list list[SpotifyEntityJson]

a list of dicts representing the Spotify items

Source code in wg_utilities/clients/spotify.py
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
def get_items(
    self,
    url: str,
    *,
    params: None | dict[str, str | int | float | bool | dict[str, Any]] = None,
    hard_limit: int = 1000000,
    limit_func: (
        Callable[
            [dict[str, Any] | SpotifyEntityJson],
            bool,
        ]
        | None
    ) = None,
    top_level_key: (
        Literal[
            "albums",
            "artists",
            "audiobooks",
            "episodes",
            "playlists",
            "shows",
            "tracks",
        ]
        | None
    ) = None,
    list_key: Literal["items", "devices"] = "items",
) -> list[SpotifyEntityJson]:
    """Retrieve a list of items from a given URL, including pagination.

    Args:
        url (str): the API endpoint which we're listing
        params (dict): any params to pass with the API request. The dict is
            copied, so the caller's instance is never modified
        hard_limit (int): a hard limit to apply to the number of items returned (as
            opposed to the "soft" limit of 50 imposed by the API)
        limit_func (Callable): a function which is used to evaluate each item in
            turn: if it returns False, the item is added to the output list; if it
            returns True then the iteration stops and the list is returned as-is
        top_level_key (str): an optional key to use when the items in the response
            are nested 1 level deeper than normal
        list_key (Literal["devices", "items"]): the key in the response which
            contains the list of items

    Returns:
        list: a list of dicts representing the Spotify items
    """

    # Work on a copy so that adding `limit` doesn't leak into the caller's dict
    params = dict(params) if params else {}
    if "limit=" not in url:
        params["limit"] = min(50, hard_limit)

    items: list[SpotifyEntityJson] = []

    if params:
        url += ("?" if "?" not in url else "&") + urlencode(params)

    # A dummy "page zero", whose `next` URL is the first real page to request
    page: AnyPaginatedResponse = {
        "href": "",
        "items": [],
        "limit": 0,
        "next": url,
        "offset": 0,
        "total": 0,
    }

    while (next_url := page.get("next")) and len(items) < hard_limit:
        # Ensure we don't bother getting more items than we need
        limit = min(50, hard_limit - len(items))
        # Replace the whole numeric value, however many digits it has
        next_url = sub(r"(?<=limit=)\d+", str(limit), next_url)

        res: SearchResponse | AnyPaginatedResponse = self.get_json_response(next_url)  # type: ignore[assignment]
        page = (
            cast("SearchResponse", res)[top_level_key]
            if top_level_key
            else cast("AnyPaginatedResponse", res)
        )

        page_items: (
            list[AlbumSummaryJson]
            | list[DeviceJson]
            | list[ArtistSummaryJson]
            | list[PlaylistSummaryJson]
            | list[TrackFullJson]
        ) = page.get(list_key, [])
        if limit_func is None:
            items.extend(page_items)
        else:
            # Initialise `limit_reached` to False, and then it will be set to
            # True on the first matching item. This will then cause the loop to
            # skip subsequent items (without calling `limit_func` again, thanks
            # to the short-circuiting `or`).
            limit_reached = False
            items.extend(
                [
                    item
                    for item in page_items
                    if (not (limit_reached := (limit_reached or limit_func(item))))
                ],
            )
            if limit_reached:
                return items

    return items

get_playlist_by_id(id_)

Get a playlist from Spotify based on the ID.

Parameters:

Name Type Description Default
id_ str

the Spotify ID which is used to find the playlist

required

Returns:

Name Type Description
Playlist Playlist

an instantiated Playlist, from the API's response

Source code in wg_utilities/clients/spotify.py
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
def get_playlist_by_id(self, id_: str) -> Playlist:
    """Get a playlist from Spotify based on the ID.

    Playlists already loaded for the current user are searched first; the API is
    only queried when no loaded playlist has the given ID.

    Args:
        id_ (str): the Spotify ID which is used to find the playlist

    Returns:
        Playlist: the matching already-loaded Playlist, or one instantiated from
            the API's response
    """

    # Only consult the user's playlists if both the user and their playlists
    # have been loaded onto their respective instances already
    if not hasattr(self, "_current_user") or not hasattr(
        self.current_user,
        "_playlists",
    ):
        loaded_playlists = ()
    else:
        loaded_playlists = self.current_user.playlists

    for candidate in loaded_playlists:
        if candidate.id == id_:
            return candidate

    playlist_json = self.get_json_response(f"/playlists/{id_}")

    return Playlist.from_json_response(playlist_json, spotify_client=self)

get_track_by_id(id_)

Get a track from Spotify based on the ID.

Parameters:

Name Type Description Default
id_ str

the Spotify ID which is used to find the track

required

Returns:

Name Type Description
Track Track

an instantiated Track, from the API's response

Source code in wg_utilities/clients/spotify.py
400
401
402
403
404
405
406
407
408
409
410
411
412
413
def get_track_by_id(self, id_: str) -> Track:
    """Fetch a single track from Spotify, by its ID.

    Args:
        id_ (str): the Spotify ID which is used to find the track

    Returns:
        Track: an instantiated Track, from the API's response
    """

    track_json = self.get_json_response(f"/tracks/{id_}")

    return Track.from_json_response(track_json, spotify_client=self)

search(search_term, *, entity_types=(), get_best_match_only=False)

Search Spotify for a given search term.

Parameters:

Name Type Description Default
search_term str

the term to use as the base of the search

required
entity_types str

the types of entity to search for. Must be one of SpotifyClient.SEARCH_TYPES

()
get_best_match_only bool

return a single entity from the top of the list, rather than all matches

False

Returns:

Name Type Description
Artist | Playlist | Track | Album | None | ParsedSearchResponse

Artist | Playlist | Track | Album: a single entity if the best match flag is set

dict Artist | Playlist | Track | Album | None | ParsedSearchResponse

a dict of entities, by type

Raises:

Type Description
ValueError

if multiple entity types have been requested but the best match flag is true

ValueError

if one of entity_types is an invalid value

Source code in wg_utilities/clients/spotify.py
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
def search(
    self,
    search_term: str,
    *,
    entity_types: Sequence[Literal["album", "artist", "playlist", "track"]] = (),
    get_best_match_only: bool = False,
) -> Artist | Playlist | Track | Album | None | ParsedSearchResponse:
    """Search Spotify for a given search term.

    Args:
        search_term (str): the term to use as the base of the search
        entity_types (Sequence[str]): the types of entity to search for. Each
            value must be one of SpotifyClient.SEARCH_TYPES. If empty, all of
            SpotifyClient.SEARCH_TYPES are searched.
        get_best_match_only (bool): return a single entity from the top of the
            list, rather than all matches

    Returns:
        Artist | Playlist | Track | Album | None: if the best match flag is set,
            the first entity in the response, or None if there were no matches
        dict: otherwise, a dict of entity lists keyed by the (plural) entity
            type used in the response

    Raises:
        ValueError: if the best match flag is true but the number of entity
            types being searched is not exactly one
        ValueError: if one of entity_types is an invalid value
    """

    entity_types = entity_types or self.SEARCH_TYPES

    if get_best_match_only is True and len(entity_types) != 1:
        raise ValueError(
            "Exactly one entity type must be requested if `get_best_match_only`"
            " is True",
        )

    entity_type: Literal["artist", "playlist", "track", "album"]
    for entity_type in entity_types:
        if entity_type not in self.SEARCH_TYPES:
            raise ValueError(
                f"Unexpected value for entity type: '{entity_type}'. Must be"
                f" one of {self.SEARCH_TYPES!r}",
            )

    res: SearchResponse = self.get_json_response(  # type: ignore[assignment]
        "/search",
        params={
            "query": search_term,
            "type": ",".join(entity_types),
            # Only the top hit is needed for a best-match lookup
            "limit": 1 if get_best_match_only else 50,
        },
    )

    entity_instances: ParsedSearchResponse = {}

    res_entity_type: Literal["albums", "artists", "playlists", "tracks"]
    entities_json: (
        PaginatedResponseAlbums
        | PaginatedResponseArtists
        | PaginatedResponsePlaylists
        | PaginatedResponseTracks
    )
    for res_entity_type, entities_json in res.items():  # type: ignore[assignment]
        # The response is keyed by the plural entity type
        instance_class: type[Album] | type[Artist] | type[Playlist] | type[Track] = {  # type: ignore[assignment]
            "albums": Album,
            "artists": Artist,
            "playlists": Playlist,
            "tracks": Track,
        }[res_entity_type]

        if get_best_match_only:
            try:
                # Take the entity off the top of the list
                return instance_class.from_json_response(
                    entities_json["items"][0],
                    spotify_client=self,
                )
            except LookupError:
                # Missing "items" key (KeyError) or empty list (IndexError)
                return None

        entity_instances.setdefault(res_entity_type, []).extend(
            [
                instance_class.from_json_response(entity_json, spotify_client=self)  # type: ignore[misc]
                for entity_json in entities_json.get("items", [])
            ],
        )

        # Each entity type has its own type-specific next URL
        if (next_url := entities_json.get("next")) is not None:
            entity_instances[res_entity_type].extend(
                [
                    instance_class.from_json_response(  # type: ignore[misc]
                        item,
                        spotify_client=self,
                    )
                    for item in self.get_items(
                        next_url,
                        top_level_key=res_entity_type,
                    )
                ],
            )

    if get_best_match_only:
        # The response held no result sets at all, so there is no best match.
        # Return None (as for an empty result list) rather than an empty dict.
        return None

    return entity_instances

TrueLayerClient

Bases: OAuthClient[dict[Literal['results'], list[TrueLayerEntityJson]]]

Custom client for interacting with TrueLayer's APIs.

Source code in wg_utilities/clients/truelayer.py
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
class TrueLayerClient(OAuthClient[dict[Literal["results"], list[TrueLayerEntityJson]]]):
    """Custom client for interacting with TrueLayer's APIs."""

    AUTH_LINK_BASE = "https://auth.truelayer.com/"
    ACCESS_TOKEN_ENDPOINT = "https://auth.truelayer.com/connect/token"  # noqa: S105
    BASE_URL = "https://api.truelayer.com"

    # Scopes requested when the caller doesn't supply their own
    DEFAULT_SCOPES: ClassVar[list[str]] = [
        "info",
        "accounts",
        "balance",
        "cards",
        "transactions",
        "direct_debits",
        "standing_orders",
        "offline_access",
    ]

    def __init__(  # noqa: PLR0913
        self,
        *,
        client_id: str,
        client_secret: str,
        log_requests: bool = False,
        creds_cache_path: Path | None = None,
        creds_cache_dir: Path | None = None,
        scopes: list[str] | None = None,
        oauth_login_redirect_host: str = "localhost",
        oauth_redirect_uri_override: str | None = None,
        headless_auth_link_callback: Callable[[str], None] | None = None,
        use_existing_credentials_only: bool = False,
        validate_request_success: bool = True,
        bank: Bank,
    ):
        """Initialise the client for a single bank.

        All arguments other than `bank` are passed straight through to the
        parent class's `__init__`, along with this class's TrueLayer URLs.

        Args:
            scopes (list[str] | None): the OAuth scopes to request. Falls back to
                `DEFAULT_SCOPES` if not provided (or empty).
            bank (Bank): the bank this client is for; stored as `self.bank` and
                used to separate cached credentials by bank.
        """
        super().__init__(
            base_url=self.BASE_URL,
            access_token_endpoint=self.ACCESS_TOKEN_ENDPOINT,
            auth_link_base=self.AUTH_LINK_BASE,
            client_id=client_id,
            client_secret=client_secret,
            log_requests=log_requests,
            creds_cache_path=creds_cache_path,
            creds_cache_dir=creds_cache_dir,
            scopes=scopes or self.DEFAULT_SCOPES,
            oauth_login_redirect_host=oauth_login_redirect_host,
            oauth_redirect_uri_override=oauth_redirect_uri_override,
            headless_auth_link_callback=headless_auth_link_callback,
            validate_request_success=validate_request_success,
            use_existing_credentials_only=use_existing_credentials_only,
        )

        self.bank = bank

    @staticmethod
    def _get_error_code(exc: HTTPError) -> str | None:
        """Get the `error` field from the JSON body of a failed response.

        Args:
            exc (HTTPError): the exception raised by the request

        Returns:
            str | None: the value of the `error` key, or None if there is no
                response, the body isn't valid JSON, or it isn't a JSON object
        """
        if exc.response is None:
            return None

        try:
            body = exc.response.json()
        except ValueError:
            # Non-JSON error body (e.g. an HTML gateway error): there is no
            # error code to inspect, so the caller re-raises the HTTPError
            return None

        return body.get("error") if isinstance(body, dict) else None

    def _get_entity_by_id(
        self,
        entity_id: str,
        entity_class: type[AccountOrCard],
    ) -> AccountOrCard | None:
        """Get entity info based on a given ID.

        Args:
            entity_id (str): the unique ID for the account/card
            entity_class (type): the class to instantiate with the returned info

        Returns:
            Account | Card | None: an instance of `entity_class` with associated
                info, or None if the API reports the ID as not found

        Raises:
            HTTPError: if a HTTPError is raised by the request, and it's not because
                the ID wasn't found
            ValueError: if the number of results returned from the TrueLayer API
                is not exactly one
        """
        try:
            # The endpoint name is the pluralised, lower-cased class name
            results = self.get_json_response(
                f"/data/v1/{entity_class.__name__.lower()}s/{entity_id}",
            ).get("results", [])
        except HTTPError as exc:
            if self._get_error_code(exc) == "account_not_found":
                return None
            raise

        if len(results) != 1:
            raise ValueError(
                f"Unexpected number of results when getting {entity_class.__name__}:"
                f" {len(results)}",
            )

        return entity_class.from_json_response(results[0], truelayer_client=self)

    def _list_entities(self, entity_class: type[AccountOrCard]) -> list[AccountOrCard]:
        """List all accounts/cards under the given bank account.

        Args:
            entity_class (type): the class to instantiate with the returned info

        Returns:
            list[Account | Card]: a list of `entity_class` instances with
                associated info. Empty if the bank doesn't support the endpoint.

        Raises:
            HTTPError: if a HTTPError is raised by the request, and it's not because
                the endpoint is unsupported
        """
        try:
            res = self.get_json_response(f"/data/v1/{entity_class.__name__.lower()}s")
        except HTTPError as exc:
            if self._get_error_code(exc) != "endpoint_not_supported":
                raise

            # Not every bank exposes every endpoint; treat this as "no entities"
            LOGGER.warning(
                "%ss endpoint not supported by %s",
                entity_class.__name__,
                self.bank.value,
            )
            res = {}

        return [
            entity_class.from_json_response(result, truelayer_client=self)
            for result in res.get("results", [])
        ]

    def get_account_by_id(
        self,
        account_id: str,
    ) -> Account | None:
        """Get an Account instance based on the ID.

        Args:
            account_id (str): the ID of the account

        Returns:
            Account | None: an Account instance, with all relevant info, or None
                if the account wasn't found
        """
        return self._get_entity_by_id(account_id, Account)

    def get_card_by_id(
        self,
        card_id: str,
    ) -> Card | None:
        """Get a Card instance based on the ID.

        Args:
            card_id (str): the ID of the card

        Returns:
            Card | None: a Card instance, with all relevant info, or None if the
                card wasn't found
        """
        return self._get_entity_by_id(card_id, Card)

    def list_accounts(self) -> list[Account]:
        """List all accounts under the given bank account.

        Returns:
            list[Account]: Account instances, containing all related info
        """
        return self._list_entities(Account)

    def list_cards(self) -> list[Card]:
        """List all cards under the given bank account.

        Returns:
            list[Card]: Card instances, containing all related info
        """
        return self._list_entities(Card)

    @property
    def _creds_rel_file_path(self) -> Path | None:
        """Get the credentials cache filepath relative to the cache directory.

        TrueLayer shares the same Client ID for all banks, so this overrides the default
        to separate credentials by bank.

        Returns:
            Path | None: `<class name>/<client ID>/<bank name>.json`, or None if
                the client ID can't be determined yet
        """

        try:
            client_id = self._client_id or self._credentials.client_id
        except AttributeError:
            return None

        return Path(type(self).__name__, client_id, f"{self.bank.name.lower()}.json")

get_account_by_id(account_id)

Get an Account instance based on the ID.

Parameters:

Name Type Description Default
account_id str

the ID of the account

required

Returns:

Name Type Description
Account Account | None

an Account instance, with all relevant info

Source code in wg_utilities/clients/truelayer.py
687
688
689
690
691
692
693
694
695
696
697
698
699
def get_account_by_id(
    self,
    account_id: str,
) -> Account | None:
    """Look up a single account by its ID.

    Args:
        account_id (str): the ID of the account

    Returns:
        Account | None: the result of `_get_entity_by_id` for the `Account`
            class, which may be None
    """
    account = self._get_entity_by_id(account_id, Account)
    return account

get_card_by_id(card_id)

Get a Card instance based on the ID.

Parameters:

Name Type Description Default
card_id str

the ID of the card

required

Returns:

Name Type Description
Card Card | None

a Card instance, with all relevant info

Source code in wg_utilities/clients/truelayer.py
701
702
703
704
705
706
707
708
709
710
711
712
713
def get_card_by_id(
    self,
    card_id: str,
) -> Card | None:
    """Look up a single card by its ID.

    Args:
        card_id (str): the ID of the card

    Returns:
        Card | None: the result of `_get_entity_by_id` for the `Card` class,
            which may be None
    """
    card = self._get_entity_by_id(card_id, Card)
    return card

list_accounts()

List all accounts under the given bank account.

Returns:

Type Description
list[Account]

list[Account]: Account instances, containing all related info

Source code in wg_utilities/clients/truelayer.py
715
716
717
718
719
720
721
def list_accounts(self) -> list[Account]:
    """Return every account available under the given bank account.

    Returns:
        list[Account]: the Account instances produced by `_list_entities`
    """
    accounts = self._list_entities(Account)
    return accounts

list_cards()

List all cards under the given bank account.

Returns:

Type Description
list[Card]

list[Card]: Card instances, containing all related info

Source code in wg_utilities/clients/truelayer.py
723
724
725
726
727
728
729
def list_cards(self) -> list[Card]:
    """Return every card available under the given bank account.

    Returns:
        list[Card]: the Card instances produced by `_list_entities`
    """
    cards = self._list_entities(Card)
    return cards

google_calendar

Custom client for interacting with Google's Calendar API.

Calendar

Bases: GoogleCalendarEntity

Class for Google calendar instances.

Source code in wg_utilities/clients/google_calendar.py
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
class Calendar(GoogleCalendarEntity):
    """Class for Google calendar instances."""

    access_role: Literal["freeBusyReader", "reader", "writer", "owner"] | None = Field(
        None,
        alias="accessRole",
    )
    background_color: str | None = Field(None, alias="backgroundColor")
    color_id: str | None = Field(None, alias="colorId")
    conference_properties: dict[
        Literal["allowedConferenceSolutionTypes"],
        list[Literal["eventHangout", "eventNamedHangout", "hangoutsMeet"]],
    ] = Field(alias="conferenceProperties")
    default_reminders: list[_Reminder] = Field(
        alias="defaultReminders",
        default_factory=list,
    )
    deleted: bool = False
    foreground_color: str | None = Field(None, alias="foregroundColor")
    hidden: bool = False
    kind: Literal["calendar#calendar", "calendar#calendarListEntry"]
    # NOTE(review): the default is an empty *list* although the field is typed as
    # a dict (hence the type: ignore) - confirm whether `dict` was intended
    notification_settings: dict[
        Literal["notifications"],
        list[_Notification],
    ] = Field(
        alias="notificationSettings",
        default_factory=list,  # type: ignore[assignment]
    )
    primary: bool = False
    selected: bool = False
    summary_override: str | None = Field(None, alias="summaryOverride")
    timezone: tzinfo = Field(alias="timeZone")

    # mypy can't get this type from the parent class for some reason...
    google_client: GoogleCalendarClient = Field(exclude=True)

    @field_validator("timezone", mode="before")
    @classmethod
    def validate_timezone(cls, value: str) -> tzinfo:
        """Convert the timezone string into a tzinfo object.

        Values which are already `tzinfo` instances are returned unchanged.
        """
        if isinstance(value, tzinfo):
            return value

        return ZoneInfo(value)

    @field_serializer("timezone", mode="plain", when_used="json", check_fields=True)
    def serialize_timezone(self, tz: tzinfo) -> str:
        """Serialize the timezone to a string."""

        return str(tz)

    def get_event_by_id(self, event_id: str) -> Event:
        """Get an event by its ID.

        Args:
            event_id (str): ID of the event to get

        Returns:
            Event: Event object
        """

        return self.google_client.get_event_by_id(event_id, calendar=self)

    def get_events(
        self,
        page_size: int = 500,
        order_by: Literal["updated", "startTime"] = "updated",
        from_datetime: datetime_ | None = None,
        to_datetime: datetime_ | None = None,
        day_limit: int | None = None,
        *,
        combine_recurring_events: bool = False,
    ) -> list[Event]:
        """Retrieve events from the calendar according to a set of criteria.

        No time filter is sent unless at least one of `from_datetime`,
        `to_datetime` or `day_limit` is provided. Naive datetimes are treated as
        UTC. The time window itself is not validated here.

        Args:
            page_size (int): the number of records to return on a single response page
            order_by (Literal["updated", "startTime"]): the order of the events
                returned within the result
            from_datetime (datetime): lower bound (exclusive) for an event's end time
                to filter by. Defaults to `day_limit` days (or 90 days if no limit
                is given) before `to_datetime`.
            to_datetime (datetime): upper bound (exclusive) for an event's start time
                to filter by. Defaults to now.
            combine_recurring_events (bool): whether to return recurring events as
                single entries. When False (the default) `singleEvents` is
                requested, i.e. recurring events are expanded into instances and
                the underlying recurring events themselves are not returned.
            day_limit (int): the maximum number of days to return events for.

        Returns:
            List[Event]: a list of Event instances
        """
        params = {
            "maxResults": page_size,
            "orderBy": order_by,
            "singleEvents": str(not combine_recurring_events),
        }
        if from_datetime or to_datetime or day_limit:
            # Make both bounds timezone-aware *before* comparing them: mixing a
            # naive and an aware datetime in `min` raises a TypeError
            to_datetime = to_datetime or datetime_.now(UTC)
            if to_datetime.tzinfo is None:
                to_datetime = to_datetime.replace(tzinfo=UTC)

            from_datetime = from_datetime or to_datetime - timedelta(days=day_limit or 90)
            if from_datetime.tzinfo is None:
                from_datetime = from_datetime.replace(tzinfo=UTC)

            if day_limit is not None:
                # Force the to_datetime to be within the day_limit
                to_datetime = min(to_datetime, from_datetime + timedelta(days=day_limit))

            params["timeMin"] = from_datetime.isoformat()
            params["timeMax"] = to_datetime.isoformat()

        return [
            Event.from_json_response(
                item,
                calendar=self,
                google_client=self.google_client,
            )
            for item in self.google_client.get_items(
                f"{self.google_client.base_url}/calendars/{self.id}/events",
                params=params,  # type: ignore[arg-type]
            )
        ]

    def __str__(self) -> str:
        """Return the calendar name."""
        return self.summary

__str__()

Return the calendar name.

Source code in wg_utilities/clients/google_calendar.py
351
352
353
def __str__(self) -> str:
    """Use the calendar's summary (its name) as the string representation."""
    name = self.summary
    return name

get_event_by_id(event_id)

Get an event by its ID.

Parameters:

Name Type Description Default
event_id str

ID of the event to get

required

Returns:

Name Type Description
Event Event

Event object

Source code in wg_utilities/clients/google_calendar.py
274
275
276
277
278
279
280
281
282
283
284
def get_event_by_id(self, event_id: str) -> Event:
    """Fetch a single event from this calendar.

    Args:
        event_id (str): ID of the event to get

    Returns:
        Event: whatever the client's `get_event_by_id` returns for this
            calendar
    """
    client = self.google_client

    return client.get_event_by_id(event_id, calendar=self)

get_events(page_size=500, order_by='updated', from_datetime=None, to_datetime=None, day_limit=None, *, combine_recurring_events=False)

Retrieve events from the calendar according to a set of criteria.

Parameters:

Name Type Description Default
page_size int

the number of records to return on a single response page

500
order_by Literal['updated', 'startTime']

the order of the events returned within the result

'updated'
from_datetime datetime

lower bound (exclusive) for an event's end time to filter by. Defaults to 90 days ago.

None
to_datetime datetime

upper bound (exclusive) for an event's start time to filter by. Defaults to now.

None
combine_recurring_events bool

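whether to return each recurring event as a single entry instead of expanding it into instances. When False (the default), recurring events are expanded so that only single one-off events and instances of recurring events are returned, but not the underlying recurring events themselves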

False
day_limit int

the maximum number of days to return events for.

None

Returns:

Type Description
list[Event]

List[Event]: a list of Event instances

Raises:

Type Description
ValueError

if the time parameters are invalid

Source code in wg_utilities/clients/google_calendar.py
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
def get_events(
    self,
    page_size: int = 500,
    order_by: Literal["updated", "startTime"] = "updated",
    from_datetime: datetime_ | None = None,
    to_datetime: datetime_ | None = None,
    day_limit: int | None = None,
    *,
    combine_recurring_events: bool = False,
) -> list[Event]:
    """Retrieve events from the calendar according to a set of criteria.

    Args:
        page_size (int): the number of records to return on a single response page
        order_by (Literal["updated", "startTime"]): the order of the events
            returned within the result
        from_datetime (datetime): lower bound (exclusive) for an event's end time
            to filter by. Defaults to 90 days ago.
        to_datetime (datetime): upper bound (exclusive) for an event's start time
            to filter by. Defaults to now.
        combine_recurring_events (bool): whether to expand recurring events into
            instances and only return single one-off events and instances of recurring
            events, but not the underlying recurring events themselves
        day_limit (int): the maximum number of days to return events for.

    Returns:
        List[Event]: a list of Event instances

    Raises:
        ValueError: if the time parameters are invalid
    """
    params = {
        "maxResults": page_size,
        "orderBy": order_by,
        "singleEvents": str(not combine_recurring_events),
    }
    if from_datetime or to_datetime or day_limit:
        to_datetime = to_datetime or datetime_.utcnow()
        from_datetime = from_datetime or to_datetime - timedelta(days=day_limit or 90)

        if day_limit is not None:
            # Force the to_datetime to be within the day_limit
            to_datetime = min(to_datetime, from_datetime + timedelta(days=day_limit))

        if from_datetime.tzinfo is None:
            from_datetime = from_datetime.replace(tzinfo=UTC)

        if to_datetime.tzinfo is None:
            to_datetime = to_datetime.replace(tzinfo=UTC)

        params["timeMin"] = from_datetime.isoformat()
        params["timeMax"] = to_datetime.isoformat()

    return [
        Event.from_json_response(
            item,
            calendar=self,
            google_client=self.google_client,
        )
        for item in self.google_client.get_items(
            f"{self.google_client.base_url}/calendars/{self.id}/events",
            params=params,  # type: ignore[arg-type]
        )
    ]

serialize_timezone(tz)

Serialize the timezone to a string.

Source code in wg_utilities/clients/google_calendar.py
268
269
270
271
272
@field_serializer("timezone", mode="plain", when_used="json", check_fields=True)
def serialize_timezone(self, tz: tzinfo) -> str:
    """Render the timezone as a string (via `str`) when serializing to JSON."""
    tz_name = str(tz)

    return tz_name

validate_timezone(value) classmethod

Convert the timezone string into a tzinfo object.

Source code in wg_utilities/clients/google_calendar.py
259
260
261
262
263
264
265
266
@field_validator("timezone", mode="before")
@classmethod
def validate_timezone(cls, value: str) -> tzinfo:
    """Coerce the incoming timezone into a tzinfo object.

    Existing `tzinfo` instances pass through untouched; anything else is
    handed to `ZoneInfo`.
    """
    return value if isinstance(value, tzinfo) else ZoneInfo(value)

CalendarJson

Bases: TypedDict

JSON representation of a Calendar.

Source code in wg_utilities/clients/google_calendar.py
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
class CalendarJson(TypedDict):
    """JSON representation of a Calendar.

    Keys use the camelCase spelling of the JSON payload rather than Python
    naming conventions.
    """

    # General entity keys
    description: str | None
    etag: str
    id: str
    location: str | None
    summary: str

    # Calendar-specific keys
    # NOTE(review): the `Calendar` model also accepts
    # "calendar#calendarListEntry" for `kind` - confirm whether it belongs here
    kind: Literal["calendar#calendar"]
    timeZone: str
    conferenceProperties: dict[
        Literal["allowedConferenceSolutionTypes"],
        list[Literal["eventHangout", "eventNamedHangout", "hangoutsMeet"]],
    ]

Event

Bases: GoogleCalendarEntity

Class for Google Calendar events.

Source code in wg_utilities/clients/google_calendar.py
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
class Event(GoogleCalendarEntity):
    """Class for Google Calendar events."""

    summary: str = "(No Title)"

    attachments: list[dict[str, str]] | None = None
    attendees: list[_Attendee] = Field(default_factory=list)
    attendees_omitted: bool | None = Field(None, alias="attendeesOmitted")
    created: datetime_
    color_id: str | None = Field(None, alias="colorId")
    conference_data: _ConferenceData | None = Field(None, alias="conferenceData")
    creator: _Creator
    end: _StartEndDatetime
    end_time_unspecified: bool | None = Field(None, alias="endTimeUnspecified")
    event_type: EventType = Field(alias="eventType")
    extended_properties: dict[str, dict[str, str]] | None = Field(
        None,
        alias="extendedProperties",
    )
    guests_can_invite_others: bool | None = Field(None, alias="guestsCanInviteOthers")
    guests_can_modify: bool | None = Field(None, alias="guestsCanModify")
    guests_can_see_other_guests: bool | None = Field(
        None,
        alias="guestsCanSeeOtherGuests",
    )
    hangout_link: str | None = Field(None, alias="hangoutLink")
    html_link: str = Field(alias="htmlLink")
    ical_uid: str = Field(alias="iCalUID")
    kind: Literal["calendar#event"]
    locked: bool | None = None
    organizer: dict[str, bool | str]
    original_start_time: dict[str, str] | None = Field(None, alias="originalStartTime")
    private_copy: bool | None = Field(None, alias="privateCopy")
    recurrence: list[str] | None = None
    recurring_event_id: str | None = Field(None, alias="recurringEventId")
    reminders: _EventReminders | None = None
    sequence: int
    source: dict[str, str] | None = None
    start: _StartEndDatetime
    status: Literal["cancelled", "confirmed", "tentative"] | None = None
    transparency: str | None = None  # != transparent
    updated: datetime_
    visibility: Literal["default", "public", "private", "confidential"] | None = None

    calendar: Calendar

    def delete(self) -> None:
        """Remove this event from the calendar it belongs to."""
        client = self.google_client
        client.delete_event_by_id(event_id=self.id, calendar=self.calendar)

    @property
    def response_status(self) -> ResponseStatus:
        """User's response status.

        Returns:
            ResponseStatus: the response status of the attendee flagged as
                `self`; otherwise ACCEPTED if the creator is flagged as `self`,
                or UNKNOWN if neither applies
        """
        own_attendance = next(
            (attendee for attendee in self.attendees if attendee.self is True),
            None,
        )
        if own_attendance is not None:
            return own_attendance.response_status

        # Own events don't always have attendees
        return ResponseStatus.ACCEPTED if self.creator.self else ResponseStatus.UNKNOWN

    def __gt__(self, other: Event) -> bool:
        """Order events by start time, then end time, then summary."""

        if isinstance(other, Event):
            own_key = (self.start.datetime, self.end.datetime, self.summary)
            other_key = (other.start.datetime, other.end.datetime, other.summary)
            return own_key > other_key

        return NotImplemented

    def __lt__(self, other: Event) -> bool:
        """Order events by start time, then end time, then summary."""

        if isinstance(other, Event):
            own_key = (self.start.datetime, self.end.datetime, self.summary)
            other_key = (other.start.datetime, other.end.datetime, other.summary)
            return own_key < other_key

        return NotImplemented

    def __str__(self) -> str:
        """Return the summary followed by the ISO start/end datetimes."""
        summary = self.summary
        starts_at = self.start.datetime.isoformat()
        ends_at = self.end.datetime.isoformat()

        return f"{summary} ({starts_at} - {ends_at})"

response_status: ResponseStatus property

User's response status.

Returns:

Name Type Description
ResponseStatus ResponseStatus

the response status for the authenticated user

__gt__(other)

Compare two events by their start time, end time, or name.

Source code in wg_utilities/clients/google_calendar.py
471
472
473
474
475
476
477
478
479
480
481
def __gt__(self, other: Event) -> bool:
    """Order events by start time, then end time, then summary."""

    if isinstance(other, Event):
        own_key = (self.start.datetime, self.end.datetime, self.summary)
        other_key = (other.start.datetime, other.end.datetime, other.summary)
        return own_key > other_key

    # Let Python try the reflected operation for non-Event operands
    return NotImplemented

__lt__(other)

Compare two events by their start time, end time, or name.

Source code in wg_utilities/clients/google_calendar.py
483
484
485
486
487
488
489
490
491
492
493
def __lt__(self, other: Event) -> bool:
    """Order events by start time, then end time, then summary."""

    if isinstance(other, Event):
        own_key = (self.start.datetime, self.end.datetime, self.summary)
        other_key = (other.start.datetime, other.end.datetime, other.summary)
        return own_key < other_key

    # Let Python try the reflected operation for non-Event operands
    return NotImplemented

__str__()

Return the event's summary and start/end datetimes.

Source code in wg_utilities/clients/google_calendar.py
495
496
497
498
499
500
501
def __str__(self) -> str:
    """Return the summary followed by the ISO start/end datetimes."""
    summary = self.summary
    starts_at = self.start.datetime.isoformat()
    ends_at = self.end.datetime.isoformat()

    return f"{summary} ({starts_at} - {ends_at})"

delete()

Delete the event from the host calendar.

Source code in wg_utilities/clients/google_calendar.py
450
451
452
def delete(self) -> None:
    """Remove this event from the calendar it belongs to."""
    client = self.google_client
    client.delete_event_by_id(event_id=self.id, calendar=self.calendar)

EventJson

Bases: TypedDict

JSON representation of an Event.

Source code in wg_utilities/clients/google_calendar.py
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
class EventJson(TypedDict, total=False):
    """JSON representation of an Event.

    Declared with `total=False`, so every key is optional. Keys use the
    camelCase spelling of the JSON payload.
    """

    # General entity keys
    description: str | None
    etag: str
    id: str
    location: str | None
    summary: str | None

    # Event-specific keys
    attachments: list[dict[str, str]] | None
    attendees: list[_Attendee] | None
    attendeesOmitted: bool | None
    created: datetime_
    colorId: str | None
    conferenceData: _ConferenceData | None
    creator: _Creator
    end: _StartEndDatetime
    endTimeUnspecified: bool | None
    eventType: EventType  # "default"
    extendedProperties: dict[str, dict[str, str]] | None
    guestsCanInviteOthers: bool | None
    guestsCanModify: bool | None
    guestsCanSeeOtherGuests: bool | None
    hangoutLink: str | None
    htmlLink: str
    iCalUID: str
    kind: Literal["calendar#event"]
    locked: bool | None
    organizer: dict[str, bool | str]
    # NOTE(review): this is the only snake_case key; the `Event` model reads
    # this value via the alias "originalStartTime" - confirm which is intended
    original_start_time: dict[str, str] | None
    privateCopy: bool | None
    recurrence: list[str] | None
    recurringEventId: str | None
    reminders: _EventReminders | None
    sequence: int
    source: dict[str, str] | None
    start: _StartEndDatetime
    status: Literal["cancelled", "confirmed", "tentative"] | None
    transparency: str | None
    updated: datetime_
    visibility: Literal["default", "public", "private", "confidential"] | None

EventType

Bases: StrEnum

Enumeration for event types.

Source code in wg_utilities/clients/google_calendar.py
37
38
39
40
41
42
class EventType(StrEnum):
    """Enumeration for event types.

    Used as the type of the `eventType` key in `EventJson`; each member's value
    is the string shown on its right-hand side.
    """

    DEFAULT = "default"
    FOCUS_TIME = "focusTime"
    OUT_OF_OFFICE = "outOfOffice"

GoogleCalendarClient

Bases: GoogleClient[GoogleCalendarEntityJson]

Custom client specifically for Google's Calendar API.

Source code in wg_utilities/clients/google_calendar.py
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
class GoogleCalendarClient(GoogleClient[GoogleCalendarEntityJson]):
    """Client for Google's Calendar API (v3).

    Offers helpers to create, fetch and delete events, and exposes the user's
    calendar list and primary calendar.
    """

    BASE_URL = "https://www.googleapis.com/calendar/v3"

    DEFAULT_PARAMS: ClassVar[
        dict[StrBytIntFlt, StrBytIntFlt | Iterable[StrBytIntFlt] | None]
    ] = {
        "maxResults": "250",
    }

    DEFAULT_SCOPES: ClassVar[list[str]] = [
        "https://www.googleapis.com/auth/calendar",
        "https://www.googleapis.com/auth/calendar.events",
    ]

    # Backing store for the `primary_calendar` property; unset until first access
    _primary_calendar: Calendar

    def create_event(
        self,
        summary: str,
        start_datetime: datetime_ | date_,
        end_datetime: datetime_ | date_,
        tz: str | None = None,
        calendar: Calendar | None = None,
        extra_params: dict[str, str] | None = None,
    ) -> Event:
        """Create a new event on a calendar.

        Args:
            summary (str): the summary (title) of the event
            start_datetime (Union[datetime, date]): when the event starts
            end_datetime (Union[datetime, date]): when the event ends
            tz (str): the timezone which the event is in (IANA database name);
                falls back to the local timezone when not given
            calendar (Calendar): the calendar to add the event to; falls back to
                the primary calendar when not given
            extra_params (dict): extra top-level keys for the request body; these
                take precedence over `summary`, `start` and `end`

        Returns:
            Event: a new event instance built from the API response

        Raises:
            TypeError: if the start/end datetime params are not the correct type
        """

        calendar = calendar or self.primary_calendar
        tz = tz or str(get_localzone())

        # `datetime_` is tested before `date_` so that a value matching both is
        # sent as a "dateTime" rather than an all-day "date"
        if isinstance(start_datetime, datetime_):
            start_key = "dateTime"
        elif isinstance(start_datetime, date_):
            start_key = "date"
        else:
            raise TypeError("`start_datetime` must be either a date or a datetime")

        start_params = {"timeZone": tz, start_key: start_datetime.isoformat()}

        if isinstance(end_datetime, datetime_):
            end_key = "dateTime"
        elif isinstance(end_datetime, date_):
            end_key = "date"
        else:
            raise TypeError("`end_datetime` must be either a date or a datetime")

        end_params = {"timeZone": tz, end_key: end_datetime.isoformat()}

        payload = {
            "summary": summary,
            "start": start_params,
            "end": end_params,
        }
        payload.update(extra_params or {})

        # NOTE(review): `maxResults: None` presumably drops the class-level default
        # param for this request - confirm against GoogleClient
        event_json = self.post_json_response(
            f"/calendars/{calendar.id}/events",
            json=payload,
            params={"maxResults": None},
        )

        return Event.from_json_response(event_json, calendar=calendar, google_client=self)

    def delete_event_by_id(self, event_id: str, calendar: Calendar | None = None) -> None:
        """Delete an event from a calendar.

        Args:
            event_id (str): the ID of the event to delete
            calendar (Calendar): the calendar hosting the event; falls back to the
                primary calendar when not given
        """
        calendar = calendar or self.primary_calendar
        url = f"{self.base_url}/calendars/{calendar.id}/events/{event_id}"

        response = delete(url, headers=self.request_headers, timeout=10)

        # Surface any HTTP error status as an exception
        response.raise_for_status()

    def get_event_by_id(
        self,
        event_id: str,
        *,
        calendar: Calendar | None = None,
    ) -> Event:
        """Get a specific event by ID.

        Args:
            event_id (str): the ID of the event to get
            calendar (Calendar): the calendar to look the event up in; falls back
                to the primary calendar when not given

        Returns:
            Event: an Event instance with all relevant attributes
        """
        calendar = calendar or self.primary_calendar

        event_json = self.get_json_response(
            f"/calendars/{calendar.id}/events/{event_id}",
            params={"maxResults": None},
        )

        return Event.from_json_response(
            event_json,
            calendar=calendar,
            google_client=self,
        )

    @property
    def calendar_list(self) -> list[Calendar]:
        """List of calendars.

        Returns:
            list: a list of Calendar instances that the user has access to
        """
        calendars_json = self.get_items(
            "/users/me/calendarList",
            params={"maxResults": None},
        )

        return [
            Calendar.from_json_response(entry, google_client=self)
            for entry in calendars_json
        ]

    @property
    def primary_calendar(self) -> Calendar:
        """Primary calendar for the user.

        The calendar is requested on first access and reused afterwards.

        Returns:
            Calendar: the current user's primary calendar
        """
        if hasattr(self, "_primary_calendar"):
            return self._primary_calendar

        primary_json = self.get_json_response(
            "/calendars/primary",
            params={"maxResults": None},
        )
        self._primary_calendar = Calendar.from_json_response(
            primary_json,
            google_client=self,
        )

        return self._primary_calendar

calendar_list: list[Calendar] property

List of calendars.

Returns:

Name Type Description
list list[Calendar]

a list of Calendar instances that the user has access to

primary_calendar: Calendar property

Primary calendar for the user.

Returns:

Name Type Description
Calendar Calendar

the current user's primary calendar

create_event(summary, start_datetime, end_datetime, tz=None, calendar=None, extra_params=None)

Create an event.

Parameters:

Name Type Description Default
summary str

the summary (title) of the event

required
start_datetime Union[datetime, date]

when the event starts

required
end_datetime Union[datetime, date]

when the event ends

required
tz str

the timezone which the event is in (IANA database name)

None
calendar Calendar

the calendar to add the event to

None
extra_params dict

any extra params to pass in the request

None

Returns:

Name Type Description
Event Event

a new event instance, fresh out of the oven

Raises:

Type Description
TypeError

if the start/end datetime params are not the correct type

Source code in wg_utilities/clients/google_calendar.py
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
def create_event(
    self,
    summary: str,
    start_datetime: datetime_ | date_,
    end_datetime: datetime_ | date_,
    tz: str | None = None,
    calendar: Calendar | None = None,
    extra_params: dict[str, str] | None = None,
) -> Event:
    """Create a new event on a calendar.

    Args:
        summary (str): the summary (title) of the event
        start_datetime (Union[datetime, date]): when the event starts
        end_datetime (Union[datetime, date]): when the event ends
        tz (str): the timezone which the event is in (IANA database name);
            falls back to the local timezone when not given
        calendar (Calendar): the calendar to add the event to; falls back to
            the primary calendar when not given
        extra_params (dict): extra top-level keys for the request body; these
            take precedence over `summary`, `start` and `end`

    Returns:
        Event: a new event instance built from the API response

    Raises:
        TypeError: if the start/end datetime params are not the correct type
    """

    calendar = calendar or self.primary_calendar
    tz = tz or str(get_localzone())

    # `datetime_` is tested before `date_` so that a value matching both is sent
    # as a "dateTime" rather than an all-day "date"
    if isinstance(start_datetime, datetime_):
        start_key = "dateTime"
    elif isinstance(start_datetime, date_):
        start_key = "date"
    else:
        raise TypeError("`start_datetime` must be either a date or a datetime")

    start_params = {"timeZone": tz, start_key: start_datetime.isoformat()}

    if isinstance(end_datetime, datetime_):
        end_key = "dateTime"
    elif isinstance(end_datetime, date_):
        end_key = "date"
    else:
        raise TypeError("`end_datetime` must be either a date or a datetime")

    end_params = {"timeZone": tz, end_key: end_datetime.isoformat()}

    payload = {
        "summary": summary,
        "start": start_params,
        "end": end_params,
    }
    payload.update(extra_params or {})

    event_json = self.post_json_response(
        f"/calendars/{calendar.id}/events",
        json=payload,
        params={"maxResults": None},
    )

    return Event.from_json_response(event_json, calendar=calendar, google_client=self)

delete_event_by_id(event_id, calendar=None)

Delete an event from a calendar.

Parameters:

Name Type Description Default
event_id str

the ID of the event to delete

required
calendar Calendar

the calendar being updated

None
Source code in wg_utilities/clients/google_calendar.py
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
def delete_event_by_id(self, event_id: str, calendar: Calendar | None = None) -> None:
    """Delete an event from a calendar.

    Args:
        event_id (str): the ID of the event to delete
        calendar (Calendar): the calendar hosting the event; falls back to the
            primary calendar when not given
    """
    calendar = calendar or self.primary_calendar
    url = f"{self.base_url}/calendars/{calendar.id}/events/{event_id}"

    response = delete(url, headers=self.request_headers, timeout=10)

    # Surface any HTTP error status as an exception
    response.raise_for_status()

get_event_by_id(event_id, *, calendar=None)

Get a specific event by ID.

Parameters:

Name Type Description Default
event_id str

the ID of the event to get

required
calendar Calendar

the calendar to look the event up in (defaults to the primary calendar)

None

Returns:

Name Type Description
Event Event

an Event instance with all relevant attributes

Source code in wg_utilities/clients/google_calendar.py
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
def get_event_by_id(
    self,
    event_id: str,
    *,
    calendar: Calendar | None = None,
) -> Event:
    """Get a specific event by ID.

    Args:
        event_id (str): the ID of the event to get
        calendar (Calendar): the calendar to look the event up in; falls back
            to the primary calendar when not given

    Returns:
        Event: an Event instance with all relevant attributes
    """
    calendar = calendar or self.primary_calendar

    event_json = self.get_json_response(
        f"/calendars/{calendar.id}/events/{event_id}",
        params={"maxResults": None},
    )

    return Event.from_json_response(
        event_json,
        calendar=calendar,
        google_client=self,
    )

GoogleCalendarEntity

Bases: BaseModelWithConfig

Base class for Google Calendar entities.

Source code in wg_utilities/clients/google_calendar.py
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
class GoogleCalendarEntity(BaseModelWithConfig):
    """Shared base model for Google Calendar entities (calendars and events)."""

    description: str | None = None
    etag: str
    id: str
    location: str | None = None
    summary: str

    # The client is kept on the model but excluded from serialised output
    google_client: GoogleCalendarClient = Field(exclude=True)

    @classmethod
    def from_json_response(
        cls,
        value: GoogleCalendarEntityJson,
        google_client: GoogleCalendarClient,
        calendar: Calendar | None = None,
    ) -> Self:
        """Build a Calendar/Event instance from an API JSON response.

        Args:
            value: the JSON payload for the entity
            google_client: the client to attach to the new instance
            calendar: the host calendar; only used when building an `Event`

        Returns:
            Self: the validated model instance
        """

        # Keys in `value` are applied last, so they win over the injected client
        value_data: dict[str, Any] = {"google_client": google_client}
        value_data.update(value)

        if cls == Event:
            value_data["calendar"] = calendar

        return cls.model_validate(value_data)

    def __eq__(self, other: Any) -> bool:
        """Treat two entities as equal when they share an ID.

        Returns `NotImplemented` when `other` is not an instance of this
        object's own type.
        """
        if isinstance(other, type(self)):
            return self.id == other.id

        return NotImplemented

__eq__(other)

Compare two GoogleCalendarEntity objects by ID.

Source code in wg_utilities/clients/google_calendar.py
195
196
197
198
199
200
def __eq__(self, other: Any) -> bool:
    """Treat two entities as equal when they share an ID.

    Returns `NotImplemented` when `other` is not an instance of this object's
    own type.
    """
    if isinstance(other, type(self)):
        return self.id == other.id

    return NotImplemented

from_json_response(value, google_client, calendar=None) classmethod

Create a Calendar/Event from a JSON response.

Source code in wg_utilities/clients/google_calendar.py
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
@classmethod
def from_json_response(
    cls,
    value: GoogleCalendarEntityJson,
    google_client: GoogleCalendarClient,
    calendar: Calendar | None = None,
) -> Self:
    """Build a Calendar/Event instance from an API JSON response.

    Args:
        value: the JSON payload for the entity
        google_client: the client to attach to the new instance
        calendar: the host calendar; only used when building an `Event`

    Returns:
        Self: the validated model instance
    """

    # Keys in `value` are applied last, so they win over the injected client
    value_data: dict[str, Any] = {"google_client": google_client}
    value_data.update(value)

    if cls == Event:
        value_data["calendar"] = calendar

    return cls.model_validate(value_data)

ResponseStatus

Bases: StrEnum

Enumeration for event attendee response statuses.

Source code in wg_utilities/clients/google_calendar.py
27
28
29
30
31
32
33
34
class ResponseStatus(StrEnum):
    """Enumeration for event attendee response statuses.

    Member names do not always mirror their values: `UNCONFIRMED` carries the
    value "needsAction".
    """

    ACCEPTED = "accepted"
    DECLINED = "declined"
    TENTATIVE = "tentative"
    UNCONFIRMED = "needsAction"
    UNKNOWN = "unknown"

google_drive

Custom client for interacting with Google's Drive API.

Directory

Bases: File, _CanHaveChildren

A Google Drive directory - basically a File with extended functionality.

Source code in wg_utilities/clients/google_drive.py
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
class Directory(File, _CanHaveChildren):
    """A Google Drive directory: a File that can also hold children."""

    MIME_TYPE: ClassVar[Literal["application/vnd.google-apps.folder"]] = (
        "application/vnd.google-apps.folder"
    )

    kind: Literal[EntityKind.DIRECTORY] = Field(default=EntityKind.DIRECTORY)
    mime_type: Literal["application/vnd.google-apps.folder"] = Field(
        alias="mimeType",
        default=MIME_TYPE,
    )

    host_drive_: Drive = Field(exclude=True)

    @field_validator("kind", mode="before")
    @classmethod
    def _validate_kind(cls, value: str | None) -> str:
        """Normalise an accepted kind to `EntityKind.DIRECTORY`.

        Raises:
            ValueError: if `value` is neither the directory kind nor the file kind
        """

        # Directories are just a subtype of files, so the file kind is okay too
        if value in (EntityKind.DIRECTORY, EntityKind.FILE):
            return EntityKind.DIRECTORY

        raise ValueError(f"Invalid kind for Directory: {value}")

    @field_validator("mime_type")
    @classmethod
    def _validate_mime_type(cls, mime_type: str) -> str:
        """Pass the MIME type through unchanged, overriding the parent's validator."""

        return mime_type

    def __repr__(self) -> str:
        """Return a short representation showing the directory's ID and name."""
        return f"Directory(id={self.id!r}, name={self.name!r})"

__repr__()

Return a string representation of the directory.

Source code in wg_utilities/clients/google_drive.py
895
896
897
def __repr__(self) -> str:
    """Return a string representation of the directory.

    The result has the form `Directory(id=..., name=...)`, using the `repr` of
    both attributes.
    """
    return f"Directory(id={self.id!r}, name={self.name!r})"

Drive

Bases: _CanHaveChildren

A Google Drive: Drive - basically a Directory with extended functionality.

Source code in wg_utilities/clients/google_drive.py
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
class Drive(_CanHaveChildren):
    """A Google Drive: Drive - basically a Directory with extended functionality."""

    kind: Literal[EntityKind.DRIVE] = Field(default=EntityKind.DRIVE)
    mime_type: Literal["application/vnd.google-apps.folder"] = Field(
        alias="mimeType",
        default=Directory.MIME_TYPE,
    )

    # Optional, can be retrieved with the `describe` method or by getting the attribute
    background_image_file: _DriveBackgroundImageFile | None = Field(
        None,
        alias="backgroundImageFile",
    )
    background_image_link: str | None = Field(None, alias="backgroundImageLink")
    capabilities: _DriveCapabilities | None = None
    color_rgb: str | None = Field(None, alias="colorRgb")
    copy_requires_writer_permission: bool | None = Field(
        None,
        alias="copyRequiresWriterPermission",
    )
    created_time: datetime | None = Field(None, alias="createdTime")
    explicitly_trashed: bool | None = Field(None, alias="explicitlyTrashed")
    folder_color_rgb: str | None = Field(None, alias="folderColorRgb")
    has_thumbnail: bool | None = Field(None, alias="hasThumbnail")
    hidden: bool | None = None
    icon_link: str | None = Field(None, alias="iconLink")
    is_app_authorized: bool | None = Field(None, alias="isAppAuthorized")
    last_modifying_user: _User | None = Field(None, alias="lastModifyingUser")
    link_share_metadata: dict[
        Literal["securityUpdateEligible", "securityUpdateEnabled"],
        bool,
    ] = Field(alias="linkShareMetadata", default_factory=dict)
    modified_by_me: bool | None = Field(None, alias="modifiedByMe")
    modified_by_me_time: datetime | None = Field(None, alias="modifiedByMeTime")
    modified_time: datetime | None = Field(None, alias="modifiedTime")
    org_unit_id: str | None = Field(None, alias="orgUnitId")
    owned_by_me: bool | None = Field(None, alias="ownedByMe")
    owners: list[_User] = Field(default_factory=list)
    permissions: list[_Permission] = Field(default_factory=list)
    permission_ids: list[str] = Field(alias="permissionIds", default_factory=list)
    quota_bytes_used: float | None = Field(None, alias="quotaBytesUsed")
    restrictions: _DriveRestrictions | None = None
    shared: bool | None = None
    spaces: list[str] = Field(default_factory=list)
    starred: bool | None = None
    theme_id: str | None = Field(None, alias="themeId")
    thumbnail_version: int | None = Field(None, alias="thumbnailVersion")
    trashed: bool | None = None
    version: int | None = None
    viewed_by_me: bool | None = Field(None, alias="viewedByMe")
    viewers_can_copy_content: bool | None = Field(None, alias="viewersCanCopyContent")
    web_view_link: str | None = Field(None, alias="webViewLink")
    writers_can_share: bool | None = Field(None, alias="writersCanShare")

    # A Drive is the root of its own tree, so it has no parent or host drive
    parent_: None = Field(exclude=True, frozen=True, default=None)
    host_drive_: None = Field(exclude=True, frozen=True, default=None)

    # Caches populated by `map`; the `*_mapped` flags record which traversals ran
    _all_directories: list[Directory] = PrivateAttr(default_factory=list)
    _directories_mapped: bool = False
    _all_files: list[File] = PrivateAttr(default_factory=list)
    _files_mapped: bool = False

    @field_validator("kind", mode="before")
    @classmethod
    def _validate_kind(cls, value: str | None) -> str:
        """Normalise an accepted kind to `EntityKind.DRIVE`.

        Raises:
            ValueError: if `value` is neither the drive kind nor the file kind
        """

        # Drives are just a subtype of files, so the file kind is okay too
        if value not in (EntityKind.DRIVE, EntityKind.FILE):
            raise ValueError(f"Invalid kind for Drive: {value}")

        return EntityKind.DRIVE

    def _get_entity_by_id(
        self,
        cls: type[DriveSubEntity],
        entity_id: str,
    ) -> DriveSubEntity:
        """Get either a Directory or File by its ID.

        Args:
            cls (type): The class of the entity to get.
            entity_id (str): The ID of the entity to get.

        Returns:
            DriveSubEntity: an instance of `cls` built from the API response
        """
        # Request every field up front only when metadata is wanted on init
        file_fields = (
            "*"
            if self.google_client.item_metadata_retrieval == IMR.ON_INIT
            else "id, name, parents, mimeType, kind"
        )

        return cls.from_json_response(
            self.google_client.get_json_response(
                f"/files/{entity_id}",
                params={
                    "fields": file_fields,
                    "pageSize": None,
                },
            ),
            google_client=self.google_client,
            host_drive=self,
            _block_describe_call=True,
        )

    def get_directory_by_id(self, directory_id: str) -> Directory:
        """Get a directory by its ID.

        Already-known directories are checked first; the API is only queried
        when no known directory has the given ID.

        Args:
            directory_id (str): the ID of the directory to get

        Returns:
            Directory: the directory with the given ID
        """
        if isinstance(self._all_directories, list):
            for directory in self._all_directories:
                if directory.id == directory_id:
                    return directory

        return self._get_entity_by_id(Directory, directory_id)

    def get_file_by_id(self, file_id: str) -> File:
        """Get a file by its ID.

        Already-known files are checked first; the API is only queried when no
        known file has the given ID.

        Args:
            file_id (str): the ID of the file to get

        Returns:
            File: the file with the given ID
        """
        if isinstance(self._all_files, list):
            for file in self._all_files:
                if file.id == file_id:
                    return file

        return self._get_entity_by_id(File, file_id)

    def map(self, map_type: EntityType = EntityType.FILE) -> None:
        """Traverse the entire Drive to map its content.

        Does nothing if the requested type has already been mapped.

        Args:
            map_type (EntityType, optional): the type of entity to map. Defaults to
                EntityType.FILE.
        """

        if (map_type == EntityType.DIRECTORY and self._directories_mapped is True) or (
            map_type == EntityType.FILE and self._files_mapped is True
        ):
            return

        # May as well get all fields in initial request if we're going to do it per
        # item anyway
        file_fields = (
            "*"
            if self.google_client.item_metadata_retrieval == IMR.ON_INIT
            else "id, name, parents, mimeType, kind"
        )

        params: dict[
            StrBytIntFlt,
            StrBytIntFlt | Iterable[StrBytIntFlt] | None,
        ] = {
            "pageSize": 1000,
            "fields": f"nextPageToken, files({file_fields})",
        }

        if map_type == EntityType.DIRECTORY:
            params["q"] = f"mimeType = '{Directory.MIME_TYPE}'"

        all_items = self.google_client.get_items(
            "/files",
            list_key="files",
            params=params,
        )
        all_files = []
        all_directories = []
        # Items without a `parents` key cannot be placed in the tree
        all_items = [item for item in all_items if "parents" in item]

        # A set keeps the per-item membership test below constant-time
        known_descendent_ids = {child.id for child in self.all_known_descendents}

        def build_sub_structure(
            parent_dir: _CanHaveChildren,
        ) -> None:
            """Build the sub-structure a given directory recursively.

            Args:
                parent_dir (_CanHaveChildren): the parent directory to build the
                    sub-structure for
            """
            nonlocal all_items

            remaining_items = []

            to_be_mapped = []
            for item in all_items:
                try:
                    # Only direct children of `parent_dir` are handled in this pass
                    if parent_dir.id != item["parents"][0]:  # type: ignore[index]
                        remaining_items.append(item)
                        continue
                except LookupError:  # pragma: no cover
                    continue

                if item["id"] in known_descendent_ids:
                    # Already known: don't rebuild it, but still descend into it
                    if item["mimeType"] == Directory.MIME_TYPE:
                        to_be_mapped.append(self.get_directory_by_id(item["id"]))  # type: ignore[arg-type]

                # Can't use `kind` here as it can be `drive#file` for directories
                elif item["mimeType"] == Directory.MIME_TYPE:
                    directory = Directory.from_json_response(
                        item,
                        google_client=self.google_client,
                        parent=parent_dir,
                        host_drive=self,
                        _block_describe_call=True,
                    )
                    parent_dir.add_child(directory)
                    all_directories.append(directory)
                    to_be_mapped.append(directory)
                else:
                    file = File.from_json_response(
                        item,
                        google_client=self.google_client,
                        parent=parent_dir,
                        host_drive=self,
                        _block_describe_call=True,
                    )
                    parent_dir.add_child(file)
                    all_files.append(file)

            # Shrink the work list before recursing so deeper levels scan fewer items
            all_items = remaining_items
            for directory in to_be_mapped:
                build_sub_structure(directory)

        build_sub_structure(self)

        # NOTE(review): this replaces the cache with only the entries built in this
        # call; previously known directories are skipped above, so they drop out
        # here unless `add_child` records them elsewhere - confirm
        self._all_directories = all_directories
        self._directories_mapped = True

        if map_type != EntityType.DIRECTORY:
            self._all_files = all_files
            self._files_mapped = True

    def search(
        self,
        term: str,
        /,
        *,
        entity_type: EntityType | None = None,
        max_results: int = 50,
        exact_match: bool = False,
        created_range: tuple[datetime, datetime] | None = None,
    ) -> list[File | Directory]:
        """Search for files and directories in the Drive.

        Args:
            term (str): the term to search for
            entity_type (EntityType | None, optional): the type of
                entity to search for. Defaults to None.
            max_results (int, optional): the maximum number of results to return.
                Defaults to 50.
            exact_match (bool, optional): whether to only return results that exactly
                match the search term. Defaults to False.
            created_range (tuple[datetime, datetime], optional): a tuple containing the
                start and end of the date range to search in. Defaults to None.

        Returns:
            list[File | Directory]: the files and directories that match the search
                term

        Raises:
            ValueError: if the given entity type is not supported
        """

        file_fields = (
            "*"
            if self.google_client.item_metadata_retrieval == IMR.ON_INIT
            else "id, name, parents, mimeType, kind"
        )

        params: dict[
            StrBytIntFlt,
            StrBytIntFlt | Iterable[StrBytIntFlt] | None,
        ] = {
            "pageSize": 1 if exact_match else min(max_results, 1000),
            "fields": f"nextPageToken, files({file_fields})",
        }

        # The term sits inside a single-quoted query value, so backslashes and
        # single quotes must be escaped or a name like "Bob's file" breaks the query
        escaped_term = term.replace("\\", "\\\\").replace("'", "\\'")

        query_conditions = [
            (
                f"name = '{escaped_term}'"
                if exact_match
                else f"name contains '{escaped_term}'"
            ),
        ]

        if entity_type == EntityType.DIRECTORY:
            query_conditions.append(f"mimeType = '{Directory.MIME_TYPE}'")
        elif entity_type == EntityType.FILE:
            query_conditions.append(f"mimeType != '{Directory.MIME_TYPE}'")
        elif entity_type is not None:
            raise ValueError(
                "`entity_type` must be either EntityType.FILE or EntityType.DIRECTORY,"
                " or None to search for both",
            )

        if created_range:
            # The range is exclusive of the start and inclusive of the end
            query_conditions.append(f"createdTime > '{created_range[0].isoformat()}'")
            query_conditions.append(f"createdTime <= '{created_range[1].isoformat()}'")

        params["q"] = " and ".join(query_conditions)

        all_items = self.google_client.get_items(
            "/files",
            list_key="files",
            params=params,
        )

        return [
            (
                Directory if item["mimeType"] == Directory.MIME_TYPE else File
            ).from_json_response(
                item,
                host_drive=self,
                google_client=self.google_client,
                _block_describe_call=True,
            )
            for item in all_items
        ]

    @property
    def all_known_descendents(self) -> list[Directory | File]:
        """Get all known descendents of this drive.

        No HTTP requests are made to get these descendents, so this may not be an
        exhaustive list.

        Returns:
            list[Directory | File]: the known files followed by the known
                directories
        """
        # Reset either cache to an empty list if it is not currently a list
        if not isinstance(self._all_directories, list):
            self._all_directories = []

        if not isinstance(self._all_files, list):
            self._all_files = []

        return self._all_files + self._all_directories  # type: ignore[operator]

    @property
    def all_directories(self) -> list[Directory]:
        """Get all directories in the Drive, mapping them first if needed."""

        if self._directories_mapped is not True:
            self.map(map_type=EntityType.DIRECTORY)

        return self._all_directories

    @property
    def all_files(self) -> list[File]:
        """Get all files in the Drive, mapping them first if needed."""

        if self._files_mapped is not True:
            self.map()

        return self._all_files

    def __repr__(self) -> str:
        """Return a string representation of the drive."""
        return f"Drive(id={self.id!r}, name={self.name!r})"

all_directories: list[Directory] property

Get all directories in the Drive.

all_files: list[File] property

Get all files in the Drive.

all_known_descendents: list[Directory | File] property

Get all known descendents of this drive.

No HTTP requests are made to get these children, so this may not be an exhaustive list.

Returns:

Type Description
list[Directory | File]

list[Directory | File]: The list of children.

__repr__()

Return a string representation of the drive.

Source code in wg_utilities/clients/google_drive.py
1359
1360
1361
def __repr__(self) -> str:
    """Return a string representation of the Drive.

    Returns:
        str: the Drive's ID and name, e.g. ``Drive(id='abc', name='My Drive')``
    """
    return f"Drive(id={self.id!r}, name={self.name!r})"

get_directory_by_id(directory_id)

Get a directory by its ID.

Parameters:

Name Type Description Default
directory_id str

the ID of the directory to get

required

Returns:

Name Type Description
Directory Directory

the directory with the given ID

Source code in wg_utilities/clients/google_drive.py
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
def get_directory_by_id(self, directory_id: str) -> Directory:
    """Look up a directory by ID, preferring the locally cached directories.

    Args:
        directory_id (str): the ID of the directory to get

    Returns:
        Directory: the cached directory with a matching ID if there is one,
            otherwise whatever `_get_entity_by_id` returns for that ID
    """
    known_directories = self._all_directories

    if isinstance(known_directories, list):
        cached = next(
            (entry for entry in known_directories if entry.id == directory_id),
            None,
        )
        if cached is not None:
            return cached

    return self._get_entity_by_id(Directory, directory_id)

get_file_by_id(file_id)

Get a file by its ID.

Parameters:

Name Type Description Default
file_id str

the ID of the file to get

required

Returns:

Name Type Description
File File

the file with the given ID

Source code in wg_utilities/clients/google_drive.py
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
def get_file_by_id(self, file_id: str) -> File:
    """Look up a file by ID, preferring the locally cached files.

    Args:
        file_id (str): the ID of the file to get

    Returns:
        File: the cached file with a matching ID if there is one, otherwise
            whatever `_get_entity_by_id` returns for that ID
    """
    known_files = self._all_files

    if isinstance(known_files, list):
        cached = next(
            (entry for entry in known_files if entry.id == file_id),
            None,
        )
        if cached is not None:
            return cached

    return self._get_entity_by_id(File, file_id)

map(map_type=EntityType.FILE)

Traverse the entire Drive to map its content.

Parameters:

Name Type Description Default
map_type EntityType

the type of entity to map. Defaults to EntityType.FILE.

FILE
Source code in wg_utilities/clients/google_drive.py
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
def map(self, map_type: EntityType = EntityType.FILE) -> None:
    """Traverse the entire Drive to map its content.

    A single paginated `/files` listing is requested, then the parent/child
    structure is rebuilt locally by recursing from this Drive downwards. Items
    without a `parents` key are ignored. On completion the directory cache and
    its "mapped" flag are always set; the file cache and its flag are only set
    when `map_type` is not `EntityType.DIRECTORY`.

    Args:
        map_type (EntityType, optional): the type of entity to map. Defaults to
            EntityType.FILE.
    """

    # Nothing to do if the requested kind of mapping has already been done
    if (map_type == EntityType.DIRECTORY and self._directories_mapped is True) or (
        map_type == EntityType.FILE and self._files_mapped is True
    ):
        return

    # May as well get all fields in initial request if we're going to do it per
    # item anyway
    file_fields = (
        "*"
        if self.google_client.item_metadata_retrieval == IMR.ON_INIT
        else "id, name, parents, mimeType, kind"
    )

    params: dict[
        StrBytIntFlt,
        StrBytIntFlt | Iterable[StrBytIntFlt] | None,
    ] = {
        "pageSize": 1000,
        "fields": f"nextPageToken, files({file_fields})",
    }

    # Restrict the listing to directories when only directories are wanted
    if map_type == EntityType.DIRECTORY:
        params["q"] = f"mimeType = '{Directory.MIME_TYPE}'"

    all_items = self.google_client.get_items(
        "/files",
        list_key="files",
        params=params,
    )
    all_files = []
    all_directories = []
    # Items with no `parents` key cannot be placed in the tree, so drop them
    all_items = [item for item in all_items if "parents" in item]

    # IDs of entities that already exist locally; these are not re-created below
    known_descendent_ids = [child.id for child in self.all_known_descendents]

    def build_sub_structure(
        parent_dir: _CanHaveChildren,
    ) -> None:
        """Build the sub-structure a given directory recursively.

        Items whose first parent is `parent_dir` are consumed from the shared
        `all_items` list; everything else is kept for later recursive calls.

        Args:
            parent_dir (_CanHaveChildren): the parent directory to build the
                sub-structure for
        """
        nonlocal all_items

        remaining_items = []

        # Directories found directly under `parent_dir` that need recursing into
        to_be_mapped = []
        for item in all_items:
            try:
                if parent_dir.id != item["parents"][0]:  # type: ignore[index]
                    remaining_items.append(item)
                    continue
            except LookupError:  # pragma: no cover
                continue

            # Already-known entities are not re-instantiated; known directories
            # are still recursed into so their children get mapped
            if item["id"] in known_descendent_ids:
                if item["mimeType"] == Directory.MIME_TYPE:
                    to_be_mapped.append(self.get_directory_by_id(item["id"]))  # type: ignore[arg-type]

            # Can't use `kind` here as it can be `drive#file` for directories
            elif item["mimeType"] == Directory.MIME_TYPE:
                directory = Directory.from_json_response(
                    item,
                    google_client=self.google_client,
                    parent=parent_dir,
                    host_drive=self,
                    _block_describe_call=True,
                )
                parent_dir.add_child(directory)
                all_directories.append(directory)
                to_be_mapped.append(directory)
            else:
                file = File.from_json_response(
                    item,
                    google_client=self.google_client,
                    parent=parent_dir,
                    host_drive=self,
                    _block_describe_call=True,
                )
                parent_dir.add_child(file)
                all_files.append(file)

        # Shrink the shared work list before recursing so deeper calls only scan
        # items that have not been placed yet
        all_items = remaining_items
        for directory in to_be_mapped:
            build_sub_structure(directory)

    build_sub_structure(self)

    # NOTE(review): entities that were already known are never appended to
    # `all_directories`/`all_files` above, so these assignments appear to drop
    # them from the caches (e.g. mapping directories and then files) - confirm
    # whether that is intended.
    self._all_directories = all_directories
    self._directories_mapped = True

    if map_type != EntityType.DIRECTORY:
        self._all_files = all_files
        self._files_mapped = True

search(term, /, *, entity_type=None, max_results=50, exact_match=False, created_range=None)

Search for files and directories in the Drive.

Parameters:

Name Type Description Default
term str

the term to search for

required
entity_type EntityType | None

the type of entity to search for. Defaults to None.

None
max_results int

the maximum number of results to return. Defaults to 50.

50
exact_match bool

whether to only return results that exactly match the search term. Defaults to False.

False
created_range tuple[datetime, datetime]

a tuple containing the start and end of the date range to search in. Defaults to None.

None

Returns:

Type Description
list[File | Directory]

list[File | Directory]: the files and directories that match the search term

Raises:

Type Description
ValueError

if the given entity type is not supported

Source code in wg_utilities/clients/google_drive.py
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
def search(
    self,
    term: str,
    /,
    *,
    entity_type: EntityType | None = None,
    max_results: int = 50,
    exact_match: bool = False,
    created_range: tuple[datetime, datetime] | None = None,
) -> list[File | Directory]:
    """Search for files and directories in the Drive.

    The search term is embedded in a Drive query string, so backslashes and
    single quotes in it are escaped first; a term such as ``Valentine's Day``
    therefore produces a valid query.

    Args:
        term (str): the term to search for
        entity_type (EntityType | None, optional): the type of
            entity to search for. Defaults to None.
        max_results (int, optional): the maximum number of results to return.
            Defaults to 50.
        exact_match (bool, optional): whether to only return results that exactly
            match the search term. Defaults to False.
        created_range (tuple[datetime, datetime], optional): a tuple containing the
            start and end of the date range to search in. Defaults to None.

    Returns:
        list[File | Directory]: the files and directories that match the search
            term

    Raises:
        ValueError: if the given entity type is not supported
    """

    file_fields = (
        "*"
        if self.google_client.item_metadata_retrieval == IMR.ON_INIT
        else "id, name, parents, mimeType, kind"
    )

    params: dict[
        StrBytIntFlt,
        StrBytIntFlt | Iterable[StrBytIntFlt] | None,
    ] = {
        "pageSize": 1 if exact_match else min(max_results, 1000),
        "fields": f"nextPageToken, files({file_fields})",
    }

    # Backslashes are escaped first so that the backslashes introduced for the
    # quotes aren't themselves doubled
    escaped_term = term.replace("\\", "\\\\").replace("'", "\\'")

    query_conditions = [
        (
            f"name = '{escaped_term}'"
            if exact_match
            else f"name contains '{escaped_term}'"
        ),
    ]

    if entity_type == EntityType.DIRECTORY:
        query_conditions.append(f"mimeType = '{Directory.MIME_TYPE}'")
    elif entity_type == EntityType.FILE:
        query_conditions.append(f"mimeType != '{Directory.MIME_TYPE}'")
    elif entity_type is not None:
        raise ValueError(
            "`entity_type` must be either EntityType.FILE or EntityType.DIRECTORY,"
            " or None to search for both",
        )

    if created_range:
        # Lower bound is exclusive, upper bound is inclusive
        query_conditions.append(f"createdTime > '{created_range[0].isoformat()}'")
        query_conditions.append(f"createdTime <= '{created_range[1].isoformat()}'")

    params["q"] = " and ".join(query_conditions)

    all_items = self.google_client.get_items(
        "/files",
        list_key="files",
        params=params,
    )

    # Directories are identified by MIME type; everything else becomes a File
    return [
        (
            Directory if item["mimeType"] == Directory.MIME_TYPE else File
        ).from_json_response(
            item,
            host_drive=self,
            google_client=self.google_client,
            _block_describe_call=True,
        )
        for item in all_items
    ]

EntityKind

Bases: StrEnum

Enum for the different kinds of entities that can be returned by the API.

Source code in wg_utilities/clients/google_drive.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class EntityKind(StrEnum):
    """Enum for the different kinds of entities that can be returned by the API."""

    COMMENT = "drive#comment"
    COMMENT_REPLY = "drive#commentReply"
    CHANGE = "drive#change"
    CHANNEL = "drive#channel"
    DIRECTORY = "drive#folder"
    DRIVE = "drive#drive"
    FILE = "drive#file"
    FILE_LIST = "drive#fileList"
    LABEL = "drive#label"
    PERMISSION = "drive#permission"
    REPLY = "drive#reply"
    REVISION = "drive#revision"
    TEAM_DRIVE = "drive#teamDrive"
    TEAM_DRIVE_LIST = "drive#teamDriveList"
    USER = "drive#user"

EntityType

Bases: StrEnum

Enum for the different entity types contained within a Drive.

Source code in wg_utilities/clients/google_drive.py
44
45
46
47
48
class EntityType(StrEnum):
    """Enum for the different entity types contained within a Drive."""

    DIRECTORY = "directory"
    FILE = "file"

File

Bases: _GoogleDriveEntity

A file object within Google Drive.

Source code in wg_utilities/clients/google_drive.py
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
class File(_GoogleDriveEntity):
    """A file object within Google Drive.

    Most fields are optional and may be unset when the instance is created;
    `__getattribute__` fetches them from the Drive API when they are first read.
    Field aliases are the camelCase keys used by the Drive API.
    """

    kind: EntityKind = Field(alias="kind", default=EntityKind.FILE)

    # Optional, can be retrieved with the `describe` method or by getting the attribute
    app_properties: dict[str, str] = Field(default_factory=dict)
    capabilities: _DriveCapabilities | None = None
    content_hints: _ContentHints | None = Field(None, alias="contentHints")
    content_restrictions: list[_ContentRestriction] | None = Field(
        None,
        alias="contentRestrictions",
    )
    copy_requires_writer_permission: bool | None = Field(
        None,
        alias="copyRequiresWriterPermission",
    )
    created_time: datetime | None = Field(None, alias="createdTime")
    description: str | None = None
    drive_id: str | None = Field(None, alias="driveId")
    explicitly_trashed: bool | None = Field(None, alias="explicitlyTrashed")
    export_links: dict[str, str] = Field(alias="exportLinks", default_factory=dict)
    folder_color_rgb: str | None = Field(None, alias="folderColorRgb")
    file_extension: str | None = Field(None, alias="fileExtension")
    full_file_extension: str | None = Field(None, alias="fullFileExtension")
    has_augmented_permissions: bool | None = Field(None, alias="hasAugmentedPermissions")
    has_thumbnail: bool | None = Field(None, alias="hasThumbnail")
    head_revision_id: str | None = Field(None, alias="headRevisionId")
    icon_link: str | None = Field(None, alias="iconLink")
    image_media_metadata: _ImageMediaMetadata | None = Field(
        None,
        alias="imageMediaMetadata",
    )
    is_app_authorized: bool | None = Field(None, alias="isAppAuthorized")
    label_info: dict[Literal["labels"], list[_Label]] = Field(
        alias="labelInfo",
        default_factory=dict,
    )
    last_modifying_user: _User | None = Field(None, alias="lastModifyingUser")
    link_share_metadata: dict[
        Literal["securityUpdateEligible", "securityUpdateEnabled"],
        bool,
    ] = Field(alias="linkShareMetadata", default_factory=dict)
    md5_checksum: str | None = Field(None, alias="md5Checksum")
    modified_by_me: bool | None = Field(None, alias="modifiedByMe")
    modified_by_me_time: datetime | None = Field(None, alias="modifiedByMeTime")
    modified_time: datetime | None = Field(None, alias="modifiedTime")
    original_filename: str | None = Field(None, alias="originalFilename")
    owned_by_me: bool | None = Field(None, alias="ownedByMe")
    owners: list[_User] = Field(default_factory=list)
    # Required; `_validate_parents` insists on exactly one entry
    parents: list[str]
    properties: dict[str, str] = Field(default_factory=dict)
    permissions: list[_Permission] = Field(default_factory=list)
    permission_ids: list[str] = Field(alias="permissionIds", default_factory=list)
    quota_bytes_used: float | None = Field(None, alias="quotaBytesUsed")
    resource_key: str | None = Field(None, alias="resourceKey")
    shared: bool | None = None
    sha1_checksum: str | None = Field(None, alias="sha1Checksum")
    sha256_checksum: str | None = Field(None, alias="sha256Checksum")
    shared_with_me_time: datetime | None = Field(None, alias="sharedWithMeTime")
    sharing_user: _User | None = Field(None, alias="sharingUser")
    shortcut_details: dict[
        Literal[
            "targetId",
            "targetMimeType",
            "targetResourceKey",
        ],
        str,
    ] = Field(alias="shortcutDetails", default_factory=dict)
    size: float | None = None
    spaces: list[str] = Field(default_factory=list)
    starred: bool | None = None
    thumbnail_link: str | None = Field(None, alias="thumbnailLink")
    thumbnail_version: int | None = Field(None, alias="thumbnailVersion")
    trashed: bool | None = None
    trashed_time: datetime | None = Field(None, alias="trashedTime")
    trashing_user: _User | None = Field(None, alias="trashingUser")
    version: int | None = None
    video_media_metadata: _VideoMediaMetadata | None = Field(
        None,
        alias="videoMediaMetadata",
    )
    viewed_by_me: bool | None = Field(None, alias="viewedByMe")
    viewed_by_me_time: datetime | None = Field(None, alias="viewedByMeTime")
    viewers_can_copy_content: bool | None = Field(None, alias="viewersCanCopyContent")
    web_content_link: str | None = Field(None, alias="webContentLink")
    web_view_link: str | None = Field(None, alias="webViewLink")
    writers_can_share: bool | None = Field(None, alias="writersCanShare")

    # Cache of the full API description, populated by `describe`
    _description: JSONObj = PrivateAttr(default_factory=dict)
    # Excluded fields: `__getattribute__` returns these directly rather than
    # treating them as lazily loaded metadata
    host_drive_: Drive = Field(exclude=True)
    parent_: Directory | Drive | None = Field(exclude=True)

    def __getattribute__(self, name: str) -> Any:
        """Override the default `__getattribute__` to allow for lazy metadata loading.

        Args:
            name (str): The name of the attribute to retrieve.

        Returns:
            Any: The value of the attribute.
        """

        # If the attribute isn't a field, just return the value
        # (`model_fields`/`model_fields_set` are short-circuited first because
        # the checks below read them through this same method)
        if (
            name in ("model_fields", "model_fields_set")
            or name not in self.model_fields
            or self.model_fields[name].exclude
        ):
            return super().__getattribute__(name)

        # NOTE(review): a field that is set but falsy (False, 0, empty container)
        # also takes this branch, so it is re-fetched on every access - confirm
        # that this is intended.
        if name not in self.model_fields_set or not super().__getattribute__(name):
            # If IMR is enabled, load all metadata
            if self.google_client.item_metadata_retrieval == IMR.ON_FIRST_REQUEST:
                self.describe()
                return super().__getattribute__(name)

            # Otherwise just get the single field
            google_key = self.model_fields[name].alias or name

            res = self.google_client.get_json_response(
                f"/files/{self.id}",
                params={"fields": google_key, "pageSize": None},
            )
            setattr(self, name, res.pop(google_key, None))

            # I can't just return the value of `res.pop(google_key, None)` here because
            # it needs to go through Pydantic's validators first

        return super().__getattribute__(name)

    @field_validator("mime_type")
    @classmethod
    def _validate_mime_type(cls, mime_type: str) -> str:
        """Reject the directory MIME type; directories must use `Directory`."""
        if mime_type == Directory.MIME_TYPE:
            raise ValueError("Use `Directory` class to create a directory.")

        return mime_type

    @field_validator("parents")
    @classmethod
    def _validate_parents(cls, parents: list[str]) -> list[str]:
        """Require the `parents` list to contain exactly one parent ID."""
        if len(parents) != 1:
            raise ValueError(f"A {cls.__name__} must have exactly one parent.")

        return parents

    @field_validator("parent_")
    @classmethod
    def _validate_parent_instance(
        cls,
        value: Directory | Drive | None,
        info: ValidationInfo,
    ) -> Directory | Drive | None:
        """Validate that the parent instance's ID matches the expected parent ID.

        Args:
            value (Directory, Drive): The parent instance.
            info (ValidationInfo): Object for extra validation information/data.

        Returns:
            Directory, Drive: The parent instance.

        Raises:
            ValueError: If the parent instance's ID doesn't match the expected parent
                ID.
        """

        # No parent instance supplied; `parent` resolves it later on demand
        if value is None:
            return value

        # Relies on `parents` having been validated before `parent_`
        if value.id != info.data["parents"][0]:
            raise ValueError(
                f"Parent ID mismatch: {value.id} != {info.data['parents'][0]}",
            )

        return value

    def describe(self, *, force_update: bool = False) -> JSONObj:
        """Describe the file by requesting all available fields from the Drive API.

        Args:
            force_update (bool): re-pull the description from Google Drive, even if we
                already have the description locally

        Returns:
            dict: the description JSON for this file

        Raises:
            ValueError: if an unexpected field is returned from the Google Drive API.
        """

        # Only hit the API when forced or when no non-empty description is cached
        if (
            force_update
            or not hasattr(self, "_description")
            or not isinstance(self._description, dict)
            or not self._description
        ):
            self._description = self.google_client.get_json_response(
                f"/files/{self.id}",
                params={"fields": "*", "pageSize": None},
            )

            for key, value in self._description.items():
                # Convert the camelCase API key to the snake_case attribute name
                google_key = sub("([A-Z])", r"_\1", key).lower()

                try:
                    setattr(self, google_key, value)
                except ValueError as exc:
                    raise ValueError(
                        f"Received unexpected field {key!r} with value {value!r}"
                        " from Google Drive API",
                    ) from exc

        return self._description

    @property
    def parent(self) -> Directory | Drive:
        """Get the parent directory of this file.

        If no parent instance is stored yet, it is resolved from the single
        parent ID (either the host Drive itself or a directory looked up by
        ID), stored, and this file is added to it as a child.

        Returns:
            Directory | Drive: the parent of this file
        """
        if self.parent_ is None and isinstance(self, File | Directory):
            if (parent_id := self.parents[0]) == self.host_drive.id:
                self.parent_ = self.host_drive
            else:
                self.parent_ = self.host_drive.get_directory_by_id(parent_id)

            self.parent_.add_child(self)

        return self.parent_

    def __gt__(self, other: File) -> bool:
        """Compare two files by name (case-insensitively)."""
        return self.name.lower() > other.name.lower()

    def __lt__(self, other: File) -> bool:
        """Compare two files by name (case-insensitively)."""
        return self.name.lower() < other.name.lower()

    def __repr__(self) -> str:
        """Return a string representation of the file."""
        return f"File(id={self.id!r}, name={self.name!r})"

parent: Directory | Drive property

Get the parent directory of this file.

Returns:

Name Type Description
Directory Directory | Drive

the parent directory of this file

__getattribute__(name)

Override the default __getattribute__ to allow for lazy metadata loading.

Parameters:

Name Type Description Default
name str

The name of the attribute to retrieve.

required

Returns:

Name Type Description
Any Any

The value of the attribute.

Source code in wg_utilities/clients/google_drive.py
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
def __getattribute__(self, name: str) -> Any:
    """Override the default `__getattribute__` to allow for lazy metadata loading.

    Non-field attributes and excluded fields are returned directly. For any
    other field that is unset or falsy, the value is fetched from the Drive
    API before being returned: either everything at once via `describe` or
    just the one field, depending on the client's `item_metadata_retrieval`.

    Args:
        name (str): The name of the attribute to retrieve.

    Returns:
        Any: The value of the attribute.
    """

    # If the attribute isn't a field, just return the value
    # (`model_fields`/`model_fields_set` are short-circuited first because the
    # checks below read them through this same method)
    if (
        name in ("model_fields", "model_fields_set")
        or name not in self.model_fields
        or self.model_fields[name].exclude
    ):
        return super().__getattribute__(name)

    # NOTE(review): a field that is set but falsy (False, 0, empty container)
    # also takes this branch, so it is re-fetched on every access - confirm that
    # this is intended.
    if name not in self.model_fields_set or not super().__getattribute__(name):
        # If IMR is enabled, load all metadata
        if self.google_client.item_metadata_retrieval == IMR.ON_FIRST_REQUEST:
            self.describe()
            return super().__getattribute__(name)

        # Otherwise just get the single field
        google_key = self.model_fields[name].alias or name

        res = self.google_client.get_json_response(
            f"/files/{self.id}",
            params={"fields": google_key, "pageSize": None},
        )
        setattr(self, name, res.pop(google_key, None))

        # I can't just return the value of `res.pop(google_key, None)` here because
        # it needs to go through Pydantic's validators first

    return super().__getattribute__(name)

__gt__(other)

Compare two files by name.

Source code in wg_utilities/clients/google_drive.py
849
850
851
def __gt__(self, other: File) -> bool:
    """Return whether this file's lower-cased name sorts after `other`'s."""
    own_name = self.name.lower()
    other_name = other.name.lower()
    return own_name > other_name

__lt__(other)

Compare two files by name.

Source code in wg_utilities/clients/google_drive.py
853
854
855
def __lt__(self, other: File) -> bool:
    """Return whether this file's lower-cased name sorts before `other`'s."""
    own_name = self.name.lower()
    other_name = other.name.lower()
    return own_name < other_name

__repr__()

Return a string representation of the file.

Source code in wg_utilities/clients/google_drive.py
857
858
859
def __repr__(self) -> str:
    """Build the `File(id=..., name=...)` representation of this file."""
    file_id, file_name = self.id, self.name
    return f"File(id={file_id!r}, name={file_name!r})"

describe(*, force_update=False)

Describe the file by requesting all available fields from the Drive API.

Parameters:

Name Type Description Default
force_update bool

re-pull the description from Google Drive, even if we already have the description locally

False

Returns:

Name Type Description
dict JSONObj

the description JSON for this file

Raises:

Type Description
ValueError

if an unexpected field is returned from the Google Drive API.

Source code in wg_utilities/clients/google_drive.py
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
def describe(self, *, force_update: bool = False) -> JSONObj:
    """Fetch (or reuse) the full Drive API description of this file.

    A cached, non-empty description is returned as-is unless `force_update` is
    set. Otherwise all fields are requested from the API, and each returned
    key is converted from camelCase to snake_case and set as an attribute on
    this instance.

    Args:
        force_update (bool): re-pull the description from Google Drive, even if we
            already have the description locally

    Returns:
        dict: the description JSON for this file

    Raises:
        ValueError: if an unexpected field is returned from the Google Drive API.
    """

    # Guard clause: reuse the cached description whenever that is allowed
    if (
        not force_update
        and hasattr(self, "_description")
        and isinstance(self._description, dict)
        and self._description
    ):
        return self._description

    self._description = self.google_client.get_json_response(
        f"/files/{self.id}",
        params={"fields": "*", "pageSize": None},
    )

    for api_key, api_value in self._description.items():
        # e.g. `createdTime` -> `created_time`
        attr_name = sub("([A-Z])", r"_\1", api_key).lower()

        try:
            setattr(self, attr_name, api_value)
        except ValueError as exc:
            raise ValueError(
                f"Received unexpected field {api_key!r} with value {api_value!r}"
                " from Google Drive API",
            ) from exc

    return self._description

GoogleDriveClient

Bases: GoogleClient[JSONObj]

Custom client specifically for Google's Drive API.

Parameters:

Name Type Description Default
scopes list

a list of scopes the client can be given

None
creds_cache_path str

file path for where to cache credentials

None
Source code in wg_utilities/clients/google_drive.py
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
class GoogleDriveClient(GoogleClient[JSONObj]):
    """Client for Google's Drive API (v3).

    All constructor arguments except `item_metadata_retrieval` are forwarded to
    `GoogleClient`, together with this class's `BASE_URL`.

    Args:
        scopes (list): a list of scopes the client can be given; falls back to
            `DEFAULT_SCOPES` when empty or None
        creds_cache_path (str): file path for where to cache credentials
        item_metadata_retrieval (IMR): stored on the instance as
            `item_metadata_retrieval`; defaults to `IMR.ON_FIRST_REQUEST`
    """

    BASE_URL = "https://www.googleapis.com/drive/v3"

    DEFAULT_SCOPES: ClassVar[list[str]] = [
        "https://www.googleapis.com/auth/drive",
        "https://www.googleapis.com/auth/drive.file",
        "https://www.googleapis.com/auth/drive.readonly",
        "https://www.googleapis.com/auth/drive.metadata.readonly",
        "https://www.googleapis.com/auth/drive.appdata",
        "https://www.googleapis.com/auth/drive.metadata",
        "https://www.googleapis.com/auth/drive.photos.readonly",
    ]

    _my_drive: Drive

    def __init__(  # noqa: PLR0913
        self,
        *,
        client_id: str,
        client_secret: str,
        log_requests: bool = False,
        creds_cache_path: Path | None = None,
        creds_cache_dir: Path | None = None,
        scopes: list[str] | None = None,
        oauth_login_redirect_host: str = "localhost",
        oauth_redirect_uri_override: str | None = None,
        headless_auth_link_callback: Callable[[str], None] | None = None,
        use_existing_credentials_only: bool = False,
        validate_request_success: bool = True,
        item_metadata_retrieval: IMR = IMR.ON_FIRST_REQUEST,
    ):
        effective_scopes = scopes or self.DEFAULT_SCOPES

        super().__init__(
            client_id=client_id,
            client_secret=client_secret,
            log_requests=log_requests,
            creds_cache_path=creds_cache_path,
            creds_cache_dir=creds_cache_dir,
            scopes=effective_scopes,
            oauth_login_redirect_host=oauth_login_redirect_host,
            oauth_redirect_uri_override=oauth_redirect_uri_override,
            headless_auth_link_callback=headless_auth_link_callback,
            use_existing_credentials_only=use_existing_credentials_only,
            base_url=self.BASE_URL,
            validate_request_success=validate_request_success,
        )

        self.item_metadata_retrieval = item_metadata_retrieval

    @property
    def my_drive(self) -> Drive:
        """User's personal Drive.

        The Drive is built from the `/files/root` response on first access and
        reused afterwards.

        Returns:
            Drive: the user's root directory/main Drive
        """
        if hasattr(self, "_my_drive"):
            return self._my_drive

        root_json = self.get_json_response(
            "/files/root",
            params={"fields": "*", "pageSize": None},
        )
        self._my_drive = Drive.from_json_response(root_json, google_client=self)

        return self._my_drive

    @property
    def shared_drives(self) -> list[Drive]:
        """Get a list of all shared drives.

        The `/drives` listing is requested on every access.

        Returns:
            list: a list of Shared Drives the current user has access to
        """
        drive_entries = self.get_items(
            "/drives",
            list_key="drives",
            params={"fields": "*"},
        )

        return [
            Drive.from_json_response(entry, google_client=self)
            for entry in drive_entries
        ]

my_drive: Drive property

User's personal Drive.

Returns:

Name Type Description
Drive Drive

the user's root directory/main Drive

shared_drives: list[Drive] property

Get a list of all shared drives.

Returns:

Name Type Description
list list[Drive]

a list of Shared Drives the current user has access to

ItemMetadataRetrieval

Bases: StrEnum

The type of metadata retrieval to use for items.

Attributes:

Name Type Description
ON_DEMAND str

only retrieves single metadata items on demand. Best for reducing memory usage, but makes the most HTTP requests.

ON_FIRST_REQUEST str

retrieves all metadata items on the first request for any metadata value. Nice middle ground between memory usage and HTTP requests.

ON_INIT str

retrieves metadata on instance initialisation. Increases memory usage, makes the fewest HTTP requests. If combined with a Drive.map call, it can be used to preload all metadata for the entire Drive.

Source code in wg_utilities/clients/google_drive.py
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
class ItemMetadataRetrieval(StrEnum):
    """The type of metadata retrieval to use for items.

    Attributes:
        ON_DEMAND (str): only retrieves single metadata items on demand. Best for
            reducing memory usage, but makes the most HTTP requests.
        ON_FIRST_REQUEST (str): retrieves all metadata items on the first request for
            _any_ metadata value. Nice middle ground between memory usage and HTTP
            requests.
        ON_INIT (str): retrieves metadata on instance initialisation. Increases memory
            usage, makes the fewest HTTP requests. If combined with a `Drive.map` call,
            it can be used to preload all metadata for the entire Drive.
    """

    # As this is a `StrEnum`, every member is also a plain `str` equal to its value.
    ON_DEMAND = "on_demand"
    ON_FIRST_REQUEST = "on_first_request"
    ON_INIT = "on_init"

google_fit

Custom client for interacting with Google's Fit API.

DataSource

Class for interacting with Google Fit Data Sources.

An example of a data source is Strava, Google Fit, MyFitnessPal, etc. The ID is something like "...weight", "...calories burnt".

Parameters:

Name Type Description Default
data_source_id str

the unique ID of the data source

required
google_client GoogleClient

a GoogleClient instance, needed for getting DataSource info

required
Source code in wg_utilities/clients/google_fit.py
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
class DataSource:
    """Wrapper around a single Google Fit data source.

    An example of a data source is Strava, Google Fit, MyFitnessPal, etc. The ID is
    something _like_ "...weight", "...calories burnt".

    Args:
        data_source_id (str): the unique ID of the data source
        google_client (GoogleFitClient): the client used for every request this
            DataSource makes
    """

    class DataPointValueKeyLookupInfo(TypedDict):
        """Typing info for the Data Point lookup dict."""

        floatPoint: Literal["fpVal"]
        integer: Literal["intVal"]

    # Maps a data type's field format onto the key holding a data point's value.
    DP_VALUE_KEY_LOOKUP: ClassVar[DataPointValueKeyLookupInfo] = {
        "floatPoint": "fpVal",
        "integer": "intVal",
    }

    def __init__(self, data_source_id: str, *, google_client: GoogleFitClient):
        self.google_client = google_client
        self.data_source_id = data_source_id
        self.url = f"/users/me/dataSources/{data_source_id}"

        # Declared only; the value is fetched lazily by the `description` property.
        self._description: _DataSourceDescriptionInfo

    @property
    def description(self) -> _DataSourceDescriptionInfo:
        """JSON description of the data source, fetched once and then cached.

        Returns:
            dict: the JSON description of this data source
        """
        if hasattr(self, "_description"):
            return self._description

        self._description = self.google_client.get_json_response(self.url)
        return self._description

    def sum_data_points_in_range(
        self,
        from_datetime: datetime | None = None,
        to_datetime: datetime | None = None,
    ) -> int:
        """Add up the values of all data points which fall inside a time window.

        Only points which both start and end within the window are counted. If no
        `from_datetime` is provided, it defaults to the start of today; if no
        `to_datetime` is provided then it defaults to now.

        Args:
            from_datetime (datetime): lower boundary of the window. Defaults to
                start of today.
            to_datetime (datetime): upper boundary of the window. Defaults to now.

        Returns:
            int: a sum of data points in the given range
        """
        # Both boundaries are expressed as integer nanosecond timestamps.
        if from_datetime:
            from_nano = int(from_datetime.timestamp() * 1000000000)
        else:
            midnight = datetime.today().replace(
                hour=0,
                minute=0,
                second=0,
                microsecond=0,
            )
            from_nano = int(midnight.timestamp() / DFUnit.NANOSECOND.value)

        if to_datetime:
            to_nano = int(to_datetime.timestamp() * 1000000000)
        else:
            to_nano = int(utcnow(DFUnit.NANOSECOND))

        data_points: list[_GoogleFitDataPointInfo] = self.google_client.get_items(
            f"{self.url}/datasets/{from_nano}-{to_nano}",
            list_key="point",
        )

        total = 0
        for point in data_points:
            if int(point["startTimeNanos"]) < from_nano:
                continue

            if int(point["endTimeNanos"]) > to_nano:
                continue

            # The value key is looked up per point, so the description is only
            # needed once at least one point is inside the window.
            total += point["value"][0][self.data_point_value_key]

        return total

    @property
    def data_type_field_format(
        self,
    ) -> Literal["floatPoint", "integer"]:
        """Field format of the data type.

        Original return type on here was as follows, think it was for other endpoints
        I haven't implemented

        ```
        Literal[
            "blob", "floatList", "floatPoint", "integer", "integerList", "map", "string"
        ]
        ```

        Returns:
            str: the field format of this data source (i.e. "integer" or "floatPoint")

        Raises:
            ValueError: if the data type does not have exactly 1 field
        """
        fields = self.description["dataType"]["field"]

        if len(fields) == 1:
            return fields[0]["format"]

        field_names = ", ".join(f["name"] for f in fields)
        raise ValueError(
            f"Unexpected number of dataType fields ({len(fields)}): " + field_names,
        )

    @property
    def data_point_value_key(self) -> Literal["fpVal", "intVal"]:
        """Key under which a data point stores its value.

        Returns:
            str: the key to use when extracting data from a data point
        """
        field_format = self.data_type_field_format

        return self.DP_VALUE_KEY_LOOKUP[field_format]

data_point_value_key: Literal['fpVal', 'intVal'] property

Key to use when looking up the value of a data point.

Returns:

Name Type Description
str Literal['fpVal', 'intVal']

the key to use when extracting data from a data point

data_type_field_format: Literal['floatPoint', 'integer'] property

Field format of the data type.

Original return type on here was as follows, think it was for other endpoints I haven't implemented

Literal[
    "blob", "floatList", "floatPoint", "integer", "integerList", "map", "string"
]

Returns:

Name Type Description
str Literal['floatPoint', 'integer']

the field format of this data source (i.e. "integer" or "floatPoint")

Raises:

Type Description
ValueError

if the number of dataType field values found is not exactly 1

description: _DataSourceDescriptionInfo property

Description of the data source, in JSON format.

Returns:

Name Type Description
dict _DataSourceDescriptionInfo

the JSON description of this data source

DataPointValueKeyLookupInfo

Bases: TypedDict

Typing info for the Data Point lookup dict.

Source code in wg_utilities/clients/google_fit.py
57
58
59
60
61
class DataPointValueKeyLookupInfo(TypedDict):
    """Typing info for the Data Point lookup dict.

    Describes the shape of `DataSource.DP_VALUE_KEY_LOOKUP`: each key is a data type
    field format, and each value is the data point key which holds the reading.
    """

    # Data sources with the "floatPoint" format keep their reading under "fpVal".
    floatPoint: Literal["fpVal"]
    # Data sources with the "integer" format keep their reading under "intVal".
    integer: Literal["intVal"]

sum_data_points_in_range(from_datetime=None, to_datetime=None)

Get the sum of data points in the given range.

If no from_datetime is provided, it defaults to the start of today; if no to_datetime is provided then it defaults to now.

Parameters:

Name Type Description Default
from_datetime datetime

lower boundary for step count. Defaults to start of today.

None
to_datetime datetime

upper boundary for step count. Defaults to now.

None

Returns:

Name Type Description
int int

a sum of data points in the given range

Source code in wg_utilities/clients/google_fit.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def sum_data_points_in_range(
    self,
    from_datetime: datetime | None = None,
    to_datetime: datetime | None = None,
) -> int:
    """Add up the values of all data points which fall inside a time window.

    Only points which both start and end within the window are counted. If no
    `from_datetime` is provided, it defaults to the start of today; if no
    `to_datetime` is provided then it defaults to now.

    Args:
        from_datetime (datetime): lower boundary of the window. Defaults to
            start of today.
        to_datetime (datetime): upper boundary of the window. Defaults to now.

    Returns:
        int: a sum of data points in the given range
    """
    # Both boundaries are expressed as integer nanosecond timestamps.
    if from_datetime:
        from_nano = int(from_datetime.timestamp() * 1000000000)
    else:
        midnight = datetime.today().replace(
            hour=0,
            minute=0,
            second=0,
            microsecond=0,
        )
        from_nano = int(midnight.timestamp() / DFUnit.NANOSECOND.value)

    if to_datetime:
        to_nano = int(to_datetime.timestamp() * 1000000000)
    else:
        to_nano = int(utcnow(DFUnit.NANOSECOND))

    data_points: list[_GoogleFitDataPointInfo] = self.google_client.get_items(
        f"{self.url}/datasets/{from_nano}-{to_nano}",
        list_key="point",
    )

    total = 0
    for point in data_points:
        if int(point["startTimeNanos"]) < from_nano:
            continue

        if int(point["endTimeNanos"]) > to_nano:
            continue

        # The value key is looked up per point, exactly when a point is counted.
        total += point["value"][0][self.data_point_value_key]

    return total

GoogleFitClient

Bases: GoogleClient[Any]

Custom client for interacting with the Google Fit API.

See Also

GoogleClient: the base Google client, used for authentication and common functions

Source code in wg_utilities/clients/google_fit.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
class GoogleFitClient(GoogleClient[Any]):
    """Client for talking to the Google Fit API.

    See Also:
        GoogleClient: the base Google client, used for authentication and common functions
    """

    BASE_URL = "https://www.googleapis.com/fitness/v1"

    DEFAULT_SCOPES: ClassVar[list[str]] = [
        "https://www.googleapis.com/auth/fitness.activity.read",
        "https://www.googleapis.com/auth/fitness.body.read",
        "https://www.googleapis.com/auth/fitness.location.read",
        "https://www.googleapis.com/auth/fitness.nutrition.read",
    ]

    # Backing store for the `data_sources` property; created on first access.
    _data_sources: dict[str, DataSource]

    def get_data_source(self, data_source_id: str) -> DataSource:
        """Return the DataSource for the given UID, creating it if necessary.

        DataSource instances are kept in `data_sources`, so repeated lookups of the
        same UID on this client return the same instance.

        Args:
            data_source_id (str): the UID of the data source

        Returns:
            DataSource: an instance, ready to use!
        """
        known_sources = self.data_sources

        existing = known_sources.get(data_source_id)
        if existing is not None:
            return existing

        created = DataSource(data_source_id=data_source_id, google_client=self)
        known_sources[data_source_id] = created

        return created

    @property
    def data_sources(self) -> dict[str, DataSource]:
        """Data sources which have been requested through this client so far.

        Returns:
            dict: a dict of data sources, keyed by their UID
        """
        if hasattr(self, "_data_sources"):
            return self._data_sources

        self._data_sources = {}
        return self._data_sources

data_sources: dict[str, DataSource] property

Data sources available to this client.

Returns:

Name Type Description
dict dict[str, DataSource]

a dict of data sources, keyed by their UID

get_data_source(data_source_id)

Get a data source based on its UID.

DataSource instances are cached for the lifetime of the GoogleClient instance

Parameters:

Name Type Description Default
data_source_id str

the UID of the data source

required

Returns:

Name Type Description
DataSource DataSource

an instance, ready to use!

Source code in wg_utilities/clients/google_fit.py
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
def get_data_source(self, data_source_id: str) -> DataSource:
    """Return the DataSource for the given UID, creating it if necessary.

    DataSource instances are kept in `self.data_sources`, so repeated lookups of the
    same UID on this client return the same instance.

    Args:
        data_source_id (str): the UID of the data source

    Returns:
        DataSource: an instance, ready to use!
    """
    known_sources = self.data_sources

    existing = known_sources.get(data_source_id)
    if existing is not None:
        return existing

    created = DataSource(data_source_id=data_source_id, google_client=self)
    known_sources[data_source_id] = created

    return created

google_photos

Custom client for interacting with Google's Photos API.

Album

Bases: GooglePhotosEntity

Class for Google Photos albums and their metadata/content.

Source code in wg_utilities/clients/google_photos.py
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
class Album(GooglePhotosEntity):
    """A Google Photos album: its metadata plus lazy access to its media items."""

    cover_photo_base_url: str = Field(alias="coverPhotoBaseUrl")
    cover_photo_media_item_id: str = Field(alias="coverPhotoMediaItemId")
    is_writeable: bool | None = Field(None, alias="isWriteable")
    media_items_count: int = Field(alias="mediaItemsCount")
    share_info: _ShareInfoInfo | None = Field(None, alias="shareInfo")
    title: str

    # Cache for the `media_items` property; only set after the first access.
    _media_items: list[MediaItem]

    @field_validator("title")
    @classmethod
    def _validate_title(cls, value: str) -> str:
        """Reject an empty title, otherwise strip surrounding whitespace."""

        if value:
            return value.strip()

        raise ValueError("Album title cannot be empty.")

    @property
    def media_items(self) -> list[MediaItem]:
        # noinspection GrazieInspection
        """All media items in the album, fetched on first access and then cached.

        Returns:
            list: a list of MediaItem instances, representing the contents of the album
        """

        if hasattr(self, "_media_items"):
            return self._media_items

        search_results = self.google_client.get_items(
            "/mediaItems:search",
            method_override=post,
            list_key="mediaItems",
            params={"albumId": self.id, "pageSize": 100},
        )

        self._media_items = [
            MediaItem.from_json_response(entry, google_client=self.google_client)
            for entry in search_results
        ]

        return self._media_items

    def __contains__(self, item: MediaItem) -> bool:
        """Check (by ID) if the album contains the given media item."""

        wanted_id = item.id

        return any(media_item.id == wanted_id for media_item in self.media_items)

media_items: list[MediaItem] property

List all media items in the album.

Returns:

Name Type Description
list list[MediaItem]

a list of MediaItem instances, representing the contents of the album

__contains__(item)

Check if the album contains the given media item.

Source code in wg_utilities/clients/google_photos.py
158
159
160
161
def __contains__(self, item: MediaItem) -> bool:
    """Check (by ID) if the album contains the given media item."""

    wanted_id = item.id

    return any(media_item.id == wanted_id for media_item in self.media_items)

AlbumJson

Bases: _GooglePhotosEntityJson

JSON representation of an Album.

Source code in wg_utilities/clients/google_photos.py
103
104
105
106
107
108
109
110
111
class AlbumJson(_GooglePhotosEntityJson):
    """JSON representation of an Album.

    Keys keep their camelCase spelling; the `Album` model declares a matching alias
    for each camelCase key.
    """

    coverPhotoBaseUrl: str
    coverPhotoMediaItemId: str
    # Typed as nullable here; the `Album` model defaults this field to `None`.
    isWriteable: bool | None
    mediaItemsCount: int
    # Typed as nullable here; the `Album` model defaults this field to `None`.
    shareInfo: _ShareInfoInfo | None
    title: str

GooglePhotosClient

Bases: GoogleClient[GooglePhotosEntityJson]

Custom client for interacting with the Google Photos API.

See Also

GoogleClient: the base Google client, used for authentication and common functions

Source code in wg_utilities/clients/google_photos.py
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
class GooglePhotosClient(GoogleClient[GooglePhotosEntityJson]):
    """Custom client for interacting with the Google Photos API.

    Albums are cached on the instance: `get_album_by_id` adds single albums to the
    cache, and the `albums` property fills in everything else on first access.

    See Also:
        GoogleClient: the base Google client, used for authentication and common functions
    """

    BASE_URL = "https://photoslibrary.googleapis.com/v1"

    DEFAULT_SCOPES: ClassVar[list[str]] = [
        "https://www.googleapis.com/auth/photoslibrary.readonly",
        "https://www.googleapis.com/auth/photoslibrary.appendonly",
        "https://www.googleapis.com/auth/photoslibrary.readonly.appcreateddata",
        "https://www.googleapis.com/auth/photoslibrary.edit.appcreateddata",
    ]

    _albums: list[Album]
    # Only really used to check if all album metadata has been fetched, not
    # available to the user (would still require caching all albums).
    _album_count: int

    def get_album_by_id(self, album_id: str) -> Album:
        """Get an album by its ID.

        The cached albums are searched first; only on a miss is the album requested
        from the API, after which it is added to the cache.

        Args:
            album_id (str): the ID of the album to fetch

        Returns:
            Album: the album with the given ID
        """

        if hasattr(self, "_albums"):
            for album in self._albums:
                if album.id == album_id:
                    return album

        album = Album.from_json_response(
            self.get_json_response(f"/albums/{album_id}", params={"pageSize": None}),
            google_client=self,
        )

        if not hasattr(self, "_albums"):
            self._albums = [album]
        else:
            self._albums.append(album)

        return album

    def get_album_by_name(self, album_name: str) -> Album:
        """Get an album definition from the Google API based on the album name.

        The first album whose title equals `album_name` exactly is returned.

        Args:
            album_name (str): the name of the album to find

        Returns:
            Album: an Album instance, with all metadata etc.

        Raises:
            FileNotFoundError: if the client can't find an album with the correct name
        """

        LOGGER.info("Getting metadata for album `%s`", album_name)
        for album in self.albums:
            if album.title == album_name:
                return album

        raise FileNotFoundError(f"Unable to find album with name {album_name!r}.")

    @property
    def albums(self) -> list[Album]:
        """List all albums in the active Google account.

        The full listing is requested at most once per client: `_album_count` is
        only set after a complete listing, so its presence marks the cache as
        complete.

        Returns:
            list: a list of Album instances
        """

        if not hasattr(self, "_albums"):
            # Nothing cached yet: the listing becomes the cache.
            self._albums = [
                Album.from_json_response(item, google_client=self)
                for item in self.get_items(
                    f"{self.BASE_URL}/albums",
                    list_key="albums",
                    params={"pageSize": 50},
                )
            ]
            self._album_count = len(self._albums)
        elif not hasattr(self, "_album_count"):
            # Some albums were cached individually (via `get_album_by_id`); add
            # only the ones not already held. A set keeps each membership check
            # O(1) rather than scanning a list for every listed album.
            known_album_ids = {album.id for album in self._albums}
            self._albums.extend(
                [
                    Album.from_json_response(item, google_client=self)
                    for item in self.get_items(
                        f"{self.BASE_URL}/albums",
                        list_key="albums",
                        params={"pageSize": 50},
                    )
                    if item["id"] not in known_album_ids
                ],
            )

            self._album_count = len(self._albums)

        return self._albums

albums: list[Album] property

List all albums in the active Google account.

Returns:

Name Type Description
list list[Album]

a list of Album instances

get_album_by_id(album_id)

Get an album by its ID.

Parameters:

Name Type Description Default
album_id str

the ID of the album to fetch

required

Returns:

Name Type Description
Album Album

the album with the given ID

Source code in wg_utilities/clients/google_photos.py
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
def get_album_by_id(self, album_id: str) -> Album:
    """Look an album up by its ID, preferring the local cache over the API.

    Args:
        album_id (str): the ID of the album to fetch

    Returns:
        Album: the album with the given ID
    """

    if hasattr(self, "_albums"):
        cached_match = next(
            (cached for cached in self._albums if cached.id == album_id),
            None,
        )
        if cached_match is not None:
            return cached_match

    # Cache miss: request the single album and remember it for next time.
    payload = self.get_json_response(f"/albums/{album_id}", params={"pageSize": None})
    album = Album.from_json_response(payload, google_client=self)

    if hasattr(self, "_albums"):
        self._albums.append(album)
    else:
        self._albums = [album]

    return album

get_album_by_name(album_name)

Get an album definition from the Google API based on the album name.

Parameters:

Name Type Description Default
album_name str

the name of the album to find

required

Returns:

Name Type Description
Album Album

an Album instance, with all metadata etc.

Raises:

Type Description
FileNotFoundError

if the client can't find an album with the correct name

Source code in wg_utilities/clients/google_photos.py
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
def get_album_by_name(self, album_name: str) -> Album:
    """Find the first album whose title exactly matches the given name.

    Args:
        album_name (str): the name of the album to find

    Returns:
        Album: an Album instance, with all metadata etc.

    Raises:
        FileNotFoundError: if the client can't find an album with the correct name
    """

    LOGGER.info("Getting metadata for album `%s`", album_name)

    matching_album = next(
        (candidate for candidate in self.albums if candidate.title == album_name),
        None,
    )

    if matching_album is None:
        raise FileNotFoundError(f"Unable to find album with name {album_name!r}.")

    return matching_album

GooglePhotosEntity

Bases: BaseModelWithConfig

Generic base class for Google Photos entities.

Source code in wg_utilities/clients/google_photos.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
class GooglePhotosEntity(BaseModelWithConfig):
    """Shared base model for entities returned by the Google Photos API."""

    id: str
    product_url: str = Field(alias="productUrl")

    # The client the entity was created with; declared with `exclude=True`.
    google_client: GooglePhotosClient = Field(exclude=True)

    @classmethod
    def from_json_response(
        cls,
        value: GooglePhotosEntityJson,
        *,
        google_client: GooglePhotosClient,
    ) -> Self:
        """Build an entity from an API JSON payload plus the client that fetched it.

        The client is merged into the payload before validation; a `google_client`
        key already present in `value` takes precedence over the argument.
        """

        return cls.model_validate({"google_client": google_client, **value})

from_json_response(value, *, google_client) classmethod

Create an entity from a JSON response.

Source code in wg_utilities/clients/google_photos.py
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
@classmethod
def from_json_response(
    cls,
    value: GooglePhotosEntityJson,
    *,
    google_client: GooglePhotosClient,
) -> Self:
    """Build an entity from an API JSON payload plus the client that fetched it.

    The client is merged into the payload before validation; a `google_client`
    key already present in `value` takes precedence over the argument.
    """

    return cls.model_validate({"google_client": google_client, **value})

MediaItem

Bases: GooglePhotosEntity

Class for representing a MediaItem and its metadata/content.

Source code in wg_utilities/clients/google_photos.py
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
class MediaItem(GooglePhotosEntity):
    """A single Google Photos media item: its metadata plus access to its content."""

    base_url: str = Field(alias="baseUrl")
    contributor_info: dict[str, str] | None = Field(default=None, alias="contributorInfo")
    description: str | None = Field(default=None)
    filename: str
    media_metadata: _MediaItemMetadata = Field(alias="mediaMetadata")
    mime_type: str = Field(alias="mimeType")

    # Only set once `download` has been called; see the `local_path` property.
    _local_path: Path

    def as_bytes(self, *, height_override: int = 0, width_override: int = 0) -> bytes:
        """Fetch the item's binary content without writing it to disk.

        Images are requested at the given (or original) dimensions; a falsy
        override falls back to the item's own width/height.
        """
        width = width_override or self.width
        height = height_override or self.height

        # Suffix appended to the base URL to select what gets returned.
        media_type = self.media_type
        if media_type == MediaType.IMAGE:
            param_str = f"=w{width}-h{height}"
        elif media_type == MediaType.VIDEO:
            param_str = "=dv"
        else:
            param_str = ""

        response = self.google_client._get(
            f"{self.base_url}{param_str}",
            params={"pageSize": None},
        )

        return response.content

    def download(
        self,
        target_directory: Path | str = "",
        *,
        file_name_override: str | None = None,
        width_override: int = 0,
        height_override: int = 0,
        force_download: bool = False,
    ) -> Path:
        """Download the media item to local storage.

        The file is placed in a `YYYY/MM/DD` subdirectory (taken from the item's
        creation datetime) of the target directory.

        Notes:
            The width/height overrides do not apply to videos.

        Args:
            target_directory (Path or str): the directory to download the file to.
                Defaults to the current working directory.
            file_name_override (str): the file name to use when downloading the file
            width_override (int): the width override to use when downloading the file
            height_override (int): the height override to use when downloading the file
            force_download (bool): flag for forcing a download, even if it exists
                locally already

        Returns:
            Path: the path to the downloaded file (self.local_path)
        """

        if isinstance(target_directory, str):
            target_directory = (
                Path(target_directory) if target_directory else Path.cwd()
            )

        dated_directory = target_directory / self.creation_datetime.strftime("%Y/%m/%d")
        self._local_path = dated_directory / (file_name_override or self.filename)

        destination = self.local_path

        if not destination.is_file() or force_download:
            force_mkdir(destination, path_is_file=True).write_bytes(
                self.as_bytes(
                    width_override=width_override,
                    height_override=height_override,
                ),
            )
        else:
            LOGGER.warning(
                "File already exists at `%s` and `force_download` is `False`;"
                " skipping download.",
                destination,
            )

        return destination

    @property
    def binary_content(self) -> bytes:
        """MediaItem binary content, read from the local copy.

        Downloads the file first (with default arguments) if there is no local
        copy yet, then reads it back from disk.

        Returns:
            bytes: the binary content of the file
        """
        if not self.is_downloaded:
            self.download()

        return self.local_path.read_bytes()

    @property
    def creation_datetime(self) -> datetime:
        """The datetime when the media item was created."""
        return self.media_metadata.creation_time

    @property
    def height(self) -> int:
        """MediaItem height.

        Returns:
            int: the media item's height
        """
        return self.media_metadata.height

    @property
    def is_downloaded(self) -> bool:
        """Whether the media item has been downloaded locally.

        Returns:
            bool: whether a file exists at `local_path`
        """
        location = self.local_path

        return bool(location and location.is_file())

    @property
    def local_path(self) -> Path:
        """The path which the file is/would be stored at locally.

        Returns:
            Path: where the file is/will be stored; `Path("undefined")` until
                `download` has been called
        """
        if hasattr(self, "_local_path"):
            return self._local_path

        return Path("undefined")

    @property
    def width(self) -> int:
        """MediaItem width.

        Returns:
            int: the media item's width
        """
        return self.media_metadata.width

    @property
    def media_type(self) -> MediaType:
        """Media type derived from the first part of the MIME type.

        Returns:
            MediaType: the media type (image, video, etc.) for this item, or
                `MediaType.UNKNOWN` if the MIME type isn't a known media type
        """
        major_type = self.mime_type.partition("/")[0]

        try:
            return MediaType(major_type)
        except ValueError:
            return MediaType.UNKNOWN

binary_content: bytes property

MediaItem binary content.

Opens the local copy of the file (downloading it first if necessary) and reads the binary content of it

Returns:

Name Type Description
bytes bytes

the binary content of the file

creation_datetime: datetime property

The datetime when the media item was created.

height: int property

MediaItem height.

Returns:

Name Type Description
int int

the media item's height

is_downloaded: bool property

Whether the media item has been downloaded locally.

Returns:

Name Type Description
bool bool

whether the media item has been downloaded locally

local_path: Path property

The path which the file is/would be stored at locally.

Returns:

Name Type Description
Path Path

where the file is/will be stored

media_type: MediaType property

Determines the media item's file type from the JSON.

Returns:

Name Type Description
MediaType MediaType

the media type (image, video, etc.) for this item

width: int property

MediaItem width.

Returns:

Name Type Description
int int

the media item's width

as_bytes(*, height_override=0, width_override=0)

MediaItem binary content - without the download.

Source code in wg_utilities/clients/google_photos.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
def as_bytes(self, *, height_override: int = 0, width_override: int = 0) -> bytes:
    """Fetch the item's binary content without writing it to disk.

    Images are requested at the given (or original) dimensions; a falsy override
    falls back to the item's own width/height.
    """
    width = width_override or self.width
    height = height_override or self.height

    # Suffix appended to the base URL to select what gets returned.
    media_type = self.media_type
    if media_type == MediaType.IMAGE:
        param_str = f"=w{width}-h{height}"
    elif media_type == MediaType.VIDEO:
        param_str = "=dv"
    else:
        param_str = ""

    response = self.google_client._get(
        f"{self.base_url}{param_str}",
        params={"pageSize": None},
    )

    return response.content

download(target_directory='', *, file_name_override=None, width_override=0, height_override=0, force_download=False)

Download the media item to local storage.

Notes

The width/height overrides do not apply to videos.

Parameters:

Name Type Description Default
target_directory Path or str

the directory to download the file to. Defaults to the current working directory.

''
file_name_override str

the file name to use when downloading the file

None
width_override int

the width override to use when downloading the file

0
height_override int

the height override to use when downloading the file

0
force_download bool

flag for forcing a download, even if it exists locally already

False

Returns:

Name Type Description
Path Path

the path to the downloaded file (self.local_path)

Source code in wg_utilities/clients/google_photos.py
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
def download(
    self,
    target_directory: Path | str = "",
    *,
    file_name_override: str | None = None,
    width_override: int = 0,
    height_override: int = 0,
    force_download: bool = False,
) -> Path:
    """Save the media item to a dated sub-directory of `target_directory`.

    The destination is `<target_directory>/<YYYY>/<MM>/<DD>/<file name>`, where
    the date parts come from `self.creation_datetime`. The computed path is
    stored on `self._local_path` whether or not a download takes place.

    Notes:
        The width/height overrides do not apply to videos.

    Args:
        target_directory (Path or str): the directory to download the file to.
            An empty string (the default) means the current working directory.
        file_name_override (str): the file name to use instead of
            `self.filename`
        width_override (int): the width override to use when downloading the file
        height_override (int): the height override to use when downloading the file
        force_download (bool): download even if a file already exists at the
            destination

    Returns:
        Path: the path to the downloaded file (`self.local_path`)
    """

    if isinstance(target_directory, str):
        target_directory = (
            Path(target_directory) if target_directory else Path.cwd()
        )

    dated_subdirectory = self.creation_datetime.strftime("%Y/%m/%d")
    self._local_path = (
        target_directory / dated_subdirectory / (file_name_override or self.filename)
    )

    if not self.local_path.is_file() or force_download:
        # The parent directories are created before the content is fetched.
        destination = force_mkdir(self.local_path, path_is_file=True)
        destination.write_bytes(
            self.as_bytes(
                width_override=width_override,
                height_override=height_override,
            ),
        )
    else:
        LOGGER.warning(
            "File already exists at `%s` and `force_download` is `False`;"
            " skipping download.",
            self.local_path,
        )

    return self.local_path

MediaItemJson

Bases: _GooglePhotosEntityJson

JSON representation of a Media Item (photo or video).

Source code in wg_utilities/clients/google_photos.py
164
165
166
167
168
169
170
171
172
class MediaItemJson(_GooglePhotosEntityJson):
    """JSON representation of a Media Item (photo or video).

    Declares the keys of a media item payload, on top of whatever keys
    `_GooglePhotosEntityJson` declares. Key names are camelCase.
    """

    # NOTE(review): presumably the URL that size/download suffixes are appended
    # to when fetching the item's bytes - confirm against `MediaItem.as_bytes`.
    baseUrl: str
    contributorInfo: dict[str, str] | None
    description: str | None
    filename: str
    mediaMetadata: _MediaItemMetadata
    mimeType: str

MediaType

Bases: Enum

Enum for all potential media types.

Source code in wg_utilities/clients/google_photos.py
63
64
65
66
67
68
class MediaType(Enum):
    """Enum for all potential media types."""

    IMAGE = "image"
    VIDEO = "video"
    # NOTE(review): presumably the fallback for items that are neither an image
    # nor a video - confirm where the media type is derived.
    UNKNOWN = "unknown"

json_api_client

Generic no-auth JSON API client to simplify interactions.

JsonApiClient

Bases: Generic[GetJsonResponse]

Generic no-auth JSON API client to simplify interactions.

Sort of an SDK?

Source code in wg_utilities/clients/json_api_client.py
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
class JsonApiClient(Generic[GetJsonResponse]):
    """Generic no-auth JSON API client to simplify interactions.

    Sort of an SDK?

    Subclasses are expected to provide `BASE_URL` (or pass `base_url` to the
    constructor) and may define `DEFAULT_PARAMS`, which are merged into the query
    parameters of every request.
    """

    BASE_URL: str

    DEFAULT_PARAMS: ClassVar[
        dict[StrBytIntFlt, StrBytIntFlt | Iterable[StrBytIntFlt] | None]
    ] = {}

    def __init__(
        self,
        *,
        log_requests: bool = False,
        base_url: str | None = None,
        validate_request_success: bool = True,
    ):
        """Initialise the client.

        Args:
            log_requests (bool): whether each request is logged at DEBUG level
            base_url (str | None): the base URL to prefix relative paths with;
                a falsy value falls back to the class-level `BASE_URL`
            validate_request_success (bool): whether `raise_for_status()` is
                called on every response
        """
        self.base_url = base_url or self.BASE_URL
        self.log_requests = log_requests
        self.validate_request_success = validate_request_success

    def _get(
        self,
        url: str,
        *,
        params: (
            dict[
                StrBytIntFlt,
                StrBytIntFlt | Iterable[StrBytIntFlt] | None,
            ]
            | None
        ) = None,
        header_overrides: Mapping[str, str | bytes] | None = None,
        timeout: float | tuple[float, float] | tuple[float, None] | None = None,
        json: Any | None = None,
        data: Any | None = None,
    ) -> Response:
        """Wrap all GET requests to cover header handling, URL parsing, etc.

        Args:
            url (str): the URL path to the endpoint (not necessarily including the
                base URL)
            params (dict): the parameters to be passed in the HTTP request
            header_overrides (dict): headers to send instead of
                `self.request_headers`
            timeout (float | tuple[float, float] | tuple[float, None] | None): the
                timeout for the request
            json (Any): the JSON to be passed in the HTTP request
            data (Any): the data to be passed in the HTTP request

        Returns:
            Response: the response from the HTTP request
        """
        return self._request(
            method=get,
            url=url,
            params=params,
            header_overrides=header_overrides,
            timeout=timeout,
            json=json,
            data=data,
        )

    def _post(
        self,
        url: str,
        *,
        params: (
            dict[
                StrBytIntFlt,
                StrBytIntFlt | Iterable[StrBytIntFlt] | None,
            ]
            | None
        ) = None,
        header_overrides: Mapping[str, str | bytes] | None = None,
        timeout: float | tuple[float, float] | tuple[float, None] | None = None,
        json: Any | None = None,
        data: Any | None = None,
    ) -> Response:
        """Wrap all POST requests to cover header handling, URL parsing, etc.

        Args:
            url (str): the URL path to the endpoint (not necessarily including the
                base URL)
            params (dict): the parameters to be passed in the HTTP request
            header_overrides (dict): headers to send instead of
                `self.request_headers`
            timeout (float | tuple[float, float] | tuple[float, None] | None): the
                timeout for the request
            json (Any): the JSON to be passed in the HTTP request
            data (Any): the data to be passed in the HTTP request

        Returns:
            Response: the response from the HTTP request
        """
        return self._request(
            method=post,
            url=url,
            params=params,
            header_overrides=header_overrides,
            timeout=timeout,
            json=json,
            data=data,
        )

    def _request(
        self,
        *,
        method: Callable[..., Response],
        url: str,
        params: (
            dict[
                StrBytIntFlt,
                StrBytIntFlt | Iterable[StrBytIntFlt] | None,
            ]
            | None
        ) = None,
        header_overrides: Mapping[str, str | bytes] | None = None,
        timeout: float | tuple[float, float] | tuple[float, None] | None = None,
        json: Any | None = None,
        data: Any | None = None,
    ) -> Response:
        """Make a HTTP request.

        `DEFAULT_PARAMS` are merged into `params` (explicit values win), and any
        parameter whose value is `None` is then dropped. Passing `{"key": None}`
        is therefore the way to suppress a default parameter for one request.
        The caller's `params` dict is never modified.

        Args:
            method (Callable): the HTTP method to use
            url (str): the URL path to the endpoint (not necessarily including the
                base URL); a leading "/" causes `self.base_url` to be prepended
            params (dict): the parameters to be passed in the HTTP request
            header_overrides (dict): headers to send instead of
                `self.request_headers`; `None` uses `self.request_headers`
            timeout (float | tuple[float, float] | tuple[float, None] | None): the
                timeout for the request
            json (dict): the JSON payload to be passed in the HTTP request
            data (dict): the data to be passed in the HTTP request

        Returns:
            Response: the response from the HTTP request

        Raises:
            Exception: whatever `Response.raise_for_status()` raises for an
                unsuccessful response, if `self.validate_request_success` is set
        """
        if params is not None:
            # Build a new dict rather than calling `params.update(...)`, so the
            # dict owned by the caller is left untouched. Key order matches the
            # previous behaviour: explicit params first, then missing defaults.
            params = {
                **params,
                **{k: v for k, v in self.DEFAULT_PARAMS.items() if k not in params},
            }
        else:
            params = deepcopy(self.DEFAULT_PARAMS)

        # `None` is the sentinel for "do not send this parameter".
        params = {k: v for k, v in params.items() if v is not None}

        if url.startswith("/"):
            url = f"{self.base_url}{url}"

        if self.log_requests:
            LOGGER.debug(
                "%s %s: %s",
                method.__name__.upper(),
                url,
                dumps(params, default=str),
            )

        res = method(
            url,
            headers=(
                header_overrides if header_overrides is not None else self.request_headers
            ),
            params=params,
            timeout=timeout,
            json=json,
            data=data,
        )

        if self.validate_request_success:
            res.raise_for_status()

        return res

    def _request_json_response(
        self,
        *,
        method: Callable[..., Response],
        url: str,
        params: (
            dict[
                StrBytIntFlt,
                StrBytIntFlt | Iterable[StrBytIntFlt] | None,
            ]
            | None
        ) = None,
        header_overrides: Mapping[str, str | bytes] | None = None,
        timeout: float | tuple[float, float] | tuple[float, None] | None = None,
        json: Any | None = None,
        data: Any | None = None,
    ) -> GetJsonResponse:
        """Make a HTTP request and decode the response body as JSON.

        Args:
            method (Callable): the HTTP method to use
            url (str): the URL path to the endpoint (not necessarily including the
                base URL)
            params (dict): the parameters to be passed in the HTTP request
            header_overrides (dict): headers to send instead of
                `self.request_headers`
            timeout (float | tuple[float, float] | tuple[float, None] | None): the
                timeout for the request
            json (Any): the JSON to be passed in the HTTP request
            data (Any): the data to be passed in the HTTP request

        Returns:
            GetJsonResponse: the decoded JSON, or an empty dict when the response
                is `204 No Content` or has an empty body

        Raises:
            ValueError: if the response has a non-empty body which is not valid
                JSON (the message is the response text)
        """
        res = self._request(
            method=method,
            url=url,
            params=params,
            header_overrides=header_overrides,
            timeout=timeout,
            json=json,
            data=data,
        )
        if res.status_code == HTTPStatus.NO_CONTENT:
            return {}  # type: ignore[return-value]

        try:
            return res.json()  # type: ignore[no-any-return]
        except JSONDecodeError as exc:
            # An empty body is treated like "no content" rather than an error.
            if not res.content:
                return {}  # type: ignore[return-value]

            raise ValueError(res.text) from exc

    def get_json_response(
        self,
        url: str,
        /,
        *,
        params: (
            dict[
                StrBytIntFlt,
                StrBytIntFlt | Iterable[StrBytIntFlt] | None,
            ]
            | None
        ) = None,
        header_overrides: Mapping[str, str | bytes] | None = None,
        timeout: float | tuple[float, float] | tuple[float, None] | None = None,
        json: Any | None = None,
        data: Any | None = None,
    ) -> GetJsonResponse:
        """Get a simple JSON object from a URL.

        Args:
            url (str): the API endpoint to GET
            params (dict): the parameters to be passed in the HTTP request
            header_overrides (dict): headers to send instead of
                `self.request_headers`. Setting this to an empty dict will erase
                all headers; `None` will use `self.request_headers`.
            timeout (float | tuple[float, float] | tuple[float, None] | None): how
                long to wait for the server before giving up
            json (dict): a JSON payload to pass in the request
            data (dict): a data payload to pass in the request

        Returns:
            dict: the JSON from the response
        """

        return self._request_json_response(
            method=get,
            url=url,
            params=params,
            header_overrides=header_overrides,
            timeout=timeout,
            json=json,
            data=data,
        )

    def post_json_response(
        self,
        url: str,
        /,
        *,
        params: (
            dict[
                StrBytIntFlt,
                StrBytIntFlt | Iterable[StrBytIntFlt] | None,
            ]
            | None
        ) = None,
        header_overrides: Mapping[str, str | bytes] | None = None,
        timeout: float | tuple[float, float] | tuple[float, None] | None = None,
        json: Any | None = None,
        data: Any | None = None,
    ) -> GetJsonResponse:
        """Get a simple JSON object from a URL from a POST request.

        Args:
            url (str): the API endpoint to POST to
            params (dict): the parameters to be passed in the HTTP request
            header_overrides (dict): headers to send instead of
                `self.request_headers`. Setting this to an empty dict will erase
                all headers; `None` will use `self.request_headers`.
            timeout (float | tuple[float, float] | tuple[float, None] | None): how
                long to wait for the server before giving up
            json (dict): a JSON payload to pass in the request
            data (dict): a data payload to pass in the request

        Returns:
            dict: the JSON from the response
        """

        return self._request_json_response(
            method=post,
            url=url,
            params=params,
            header_overrides=header_overrides,
            timeout=timeout,
            json=json,
            data=data,
        )

    @property
    def request_headers(self) -> dict[str, str]:
        """Header to be used in requests to the API.

        Returns:
            dict: the default headers for HTTP requests (only `Content-Type`)
        """
        return {
            "Content-Type": "application/json",
        }

request_headers: dict[str, str] property

Header to be used in requests to the API.

Returns:

Name Type Description
dict dict[str, str]

the default headers for HTTP requests (this no-auth client only sets Content-Type)

get_json_response(url, /, *, params=None, header_overrides=None, timeout=None, json=None, data=None)

Get a simple JSON object from a URL.

Parameters:

Name Type Description Default
url str

the API endpoint to GET

required
params dict

the parameters to be passed in the HTTP request

None
header_overrides dict

headers to send instead of self.request_headers (they replace the defaults entirely rather than being merged into them). Setting this to an empty dict will erase all headers; None will use self.request_headers.

None
timeout float

How many seconds to wait for the server to send data before giving up

None
json dict

a JSON payload to pass in the request

None
data dict

a data payload to pass in the request

None

Returns:

Name Type Description
dict GetJsonResponse

the JSON from the response

Source code in wg_utilities/clients/json_api_client.py
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
def get_json_response(
    self,
    url: str,
    /,
    *,
    params: (
        dict[
            StrBytIntFlt,
            StrBytIntFlt | Iterable[StrBytIntFlt] | None,
        ]
        | None
    ) = None,
    header_overrides: Mapping[str, str | bytes] | None = None,
    timeout: float | tuple[float, float] | tuple[float, None] | None = None,
    json: Any | None = None,
    data: Any | None = None,
) -> GetJsonResponse:
    """Get a simple JSON object from a URL.

    Args:
        url (str): the API endpoint to GET
        params (dict): the parameters to be passed in the HTTP request
        header_overrides (dict): headers to send instead of
            `self.request_headers`. Setting this to an empty dict will erase all
            headers; `None` will use `self.request_headers`.
        timeout (float | tuple[float, float] | tuple[float, None] | None): how
            long to wait for the server before giving up. Accepts the same
            forms as `post_json_response`.
        json (dict): a JSON payload to pass in the request
        data (dict): a data payload to pass in the request

    Returns:
        dict: the JSON from the response
    """

    return self._request_json_response(
        method=get,
        url=url,
        params=params,
        header_overrides=header_overrides,
        timeout=timeout,
        json=json,
        data=data,
    )

post_json_response(url, /, *, params=None, header_overrides=None, timeout=None, json=None, data=None)

Get a simple JSON object from a URL from a POST request.

Parameters:

Name Type Description Default
url str

the API endpoint to POST to

required
params dict

the parameters to be passed in the HTTP request

None
header_overrides dict

headers to send instead of self.request_headers (they replace the defaults entirely rather than being merged into them). Setting this to an empty dict will erase all headers; None will use self.request_headers.

None
timeout float

How many seconds to wait for the server to send data before giving up

None
json dict

a JSON payload to pass in the request

None
data dict

a data payload to pass in the request

None

Returns:

Name Type Description
dict GetJsonResponse

the JSON from the response

Source code in wg_utilities/clients/json_api_client.py
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
def post_json_response(
    self,
    url: str,
    /,
    *,
    params: (
        dict[
            StrBytIntFlt,
            StrBytIntFlt | Iterable[StrBytIntFlt] | None,
        ]
        | None
    ) = None,
    header_overrides: Mapping[str, str | bytes] | None = None,
    timeout: float | tuple[float, float] | tuple[float, None] | None = None,
    json: Any | None = None,
    data: Any | None = None,
) -> GetJsonResponse:
    """POST to a URL and return the JSON object from the response.

    Args:
        url (str): the API endpoint to POST to
        params (dict): the parameters to be passed in the HTTP request
        header_overrides (dict): headers to send instead of
            `self.request_headers`. Setting this to an empty dict will erase all
            headers; `None` will use `self.request_headers`.
        timeout (float | tuple[float, float] | tuple[float, None] | None): how
            long to wait for the server before giving up
        json (dict): a JSON payload to pass in the request
        data (dict): a data payload to pass in the request

    Returns:
        dict: the JSON from the response
    """

    # All of the work is delegated; only the HTTP method is fixed here.
    decoded_body = self._request_json_response(
        url=url,
        method=post,
        data=data,
        json=json,
        timeout=timeout,
        header_overrides=header_overrides,
        params=params,
    )

    return decoded_body

monzo

Custom client for interacting with Monzo's API.

Account

Bases: BaseModelWithConfig

Class for managing individual bank accounts.

Source code in wg_utilities/clients/monzo.py
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
class Account(BaseModelWithConfig):
    """Class for managing individual bank accounts.

    The `initial_*` fields hold the balance figures supplied when the account was
    validated; the `balance`-style properties return values fetched from the API,
    refreshed at most every `balance_update_threshold` minutes.
    """

    account_number: str
    closed: bool
    country_code: str
    created: datetime
    currency: Literal["GBP"]
    description: str
    id: str
    initial_balance: int | None = Field(None, validation_alias="balance")
    initial_balance_including_flexible_savings: int | None = Field(
        None,
        validation_alias="balance_including_flexible_savings",
    )
    initial_spend_today: int | None = Field(None, validation_alias="spend_today")
    initial_total_balance: int | None = Field(None, validation_alias="total_balance")
    owners: list[AccountOwner]
    payment_details: dict[str, dict[str, str]] | None = None
    sort_code: str = Field(min_length=6, max_length=6)
    type: Literal["uk_monzo_flex", "uk_retail", "uk_retail_joint"]

    monzo_client: MonzoClient = Field(exclude=True)
    balance_update_threshold: int = Field(15, exclude=True)
    last_balance_update: datetime = Field(datetime(1970, 1, 1), exclude=True)
    _balance_variables: BalanceVariables

    @field_validator("sort_code", mode="before")
    @classmethod
    def validate_sort_code(cls, sort_code: str | int) -> str:
        """Ensure that the sort code is a 6-digit integer.

        Represented as a string so leading zeroes aren't lost. Values shorter than
        `SORT_CODE_LEN` (e.g. an integer whose leading zeroes were dropped) are
        left-padded with zeroes.

        Args:
            sort_code (str | int): the raw sort code

        Returns:
            str: the sort code, left-padded with zeroes to `SORT_CODE_LEN` digits

        Raises:
            ValueError: if the sort code contains anything other than digits
        """

        if isinstance(sort_code, int):
            sort_code = str(sort_code)

        # Check before padding, so an empty string is still rejected rather than
        # being turned into "000000".
        if not sort_code.isdigit():
            raise ValueError("Sort code must be a 6-digit integer")

        # Strings are immutable: the padded value has to be returned. Padding
        # goes on the left, as that is where an integer loses its zeroes.
        return sort_code.zfill(SORT_CODE_LEN)

    @classmethod
    def from_json_response(
        cls,
        value: AccountJson,
        monzo_client: MonzoClient,
    ) -> Account:
        """Create an account from a JSON response.

        Args:
            value (AccountJson): the account JSON
            monzo_client (MonzoClient): the client stored on the new account

        Returns:
            Account: the validated account
        """

        value_data: dict[str, Any] = {
            "monzo_client": monzo_client,
            **value,
        }

        return cls.model_validate(value_data)

    def list_transactions(
        self,
        from_datetime: datetime | None = None,
        to_datetime: datetime | None = None,
        limit: int = 100,
    ) -> list[Transaction]:
        """List transactions for the account.

        Timezone-aware datetimes are converted to UTC before being sent; naive
        datetimes are sent as they are (i.e. they are taken to be UTC already).

        Args:
            from_datetime (datetime, optional): the start of the time period to list
                transactions for. Defaults to 89 days before now.
            to_datetime (datetime, optional): the end of the time period to list
                transactions for. Defaults to now.
            limit (int, optional): the maximum number of transactions to return.
                Defaults to 100.

        Returns:
            list[Transaction]: the list of transactions
        """

        def _as_naive_utc(moment: datetime) -> datetime:
            """Drop microseconds and tzinfo, converting aware values to UTC first."""
            # Without this conversion, stripping the tzinfo of e.g. a +01:00
            # datetime and appending "Z" would shift the moment by an hour.
            if moment.utcoffset() is not None:
                moment = moment.astimezone(UTC)

            return moment.replace(microsecond=0, tzinfo=None)

        from_datetime = _as_naive_utc(
            from_datetime or (datetime.now(UTC) - timedelta(days=89)),
        )
        to_datetime = _as_naive_utc(to_datetime or datetime.now(UTC))

        return [
            Transaction(**item)
            for item in self.monzo_client.get_json_response(
                "/transactions",
                params={
                    "account_id": self.id,
                    "since": from_datetime.isoformat() + "Z",
                    "before": to_datetime.isoformat() + "Z",
                    "limit": limit,
                },
            )["transactions"]
        ]

    def update_balance_variables(self) -> None:
        """Update the balance-related instance attributes.

        Latest values from the API are used. This is called automatically when
        a balance property is accessed; new values are only fetched if none have
        been fetched yet, or the last update was at least
        `balance_update_threshold` minutes ago. Can also be called manually if
        required.
        """

        # The `hasattr` check comes first so that the timestamp comparison is
        # skipped entirely until the first fetch has happened.
        if not hasattr(self, "_balance_variables") or self.last_balance_update <= (
            datetime.now(UTC) - timedelta(minutes=self.balance_update_threshold)
        ):
            LOGGER.debug("Balance variable update threshold crossed, getting new values")

            self._balance_variables = BalanceVariables.model_validate(
                self.monzo_client.get_json_response(f"/balance?account_id={self.id}"),
            )

            self.last_balance_update = datetime.now(UTC)

    @property
    def balance(self) -> int | None:
        """Current balance of the account, in pence.

        Returns:
            int: the currently available balance of the account
        """
        return self.balance_variables.balance

    @property
    def balance_variables(self) -> BalanceVariables:
        """The balance variables for the account.

        Accessing this property calls `update_balance_variables` first.

        Returns:
            BalanceVariables: the balance variables
        """
        self.update_balance_variables()

        return self._balance_variables

    @property
    def balance_including_flexible_savings(self) -> int | None:
        """Balance including flexible savings, in pence.

        Returns:
            int: the currently available balance of the account, including flexible
                savings pots
        """
        return self.balance_variables.balance_including_flexible_savings

    @property
    def spend_today(self) -> int | None:
        """Amount spent today, in pence.

        Returns:
            int: the amount spent from this account today (considered from approx
                4am onwards)
        """
        return self.balance_variables.spend_today

    @property
    def total_balance(self) -> int | None:
        """Total balance of the account, in pence.

        Returns:
            int: the sum of the currently available balance of the account and the
                combined total of all the user's pots
        """
        return self.balance_variables.total_balance

    def __eq__(self, other: object) -> bool:
        """Check if two accounts are equal.

        Accounts are equal when their `id` values match; comparing against
        anything that is not an `Account` returns `NotImplemented`.
        """
        if not isinstance(other, Account):
            return NotImplemented

        return self.id == other.id

    def __repr__(self) -> str:
        """Representation of the account."""
        return f"<Account {self.id}>"

balance: int | None property

Current balance of the account, in pence.

Returns:

Name Type Description
int int | None

the currently available balance of the account

balance_including_flexible_savings: int | None property

Balance including flexible savings, in pence.

Returns:

Name Type Description
int int | None

the currently available balance of the account, including flexible savings pots

balance_variables: BalanceVariables property

The balance variables for the account.

Returns:

Name Type Description
BalanceVariables BalanceVariables

the balance variables

spend_today: int | None property

Amount spent today, in pence.

Returns:

Name Type Description
int int | None

the amount spent from this account today (considered from approx 4am onwards)

total_balance: int | None property

Total balance of the account, in pence.

Returns:

Name Type Description
int int | None

the sum of the currently available balance of the account and the combined total of all the user's pots

__eq__(other)

Check if two accounts are equal.

Source code in wg_utilities/clients/monzo.py
390
391
392
393
394
395
def __eq__(self, other: object) -> bool:
    """Compare two accounts by their `id`.

    Returns:
        bool: whether both accounts have the same `id`; `NotImplemented` is
            returned when `other` is not an `Account`
    """
    if isinstance(other, Account):
        return self.id == other.id

    # Not an Account: defer to Python's fallback comparison handling.
    return NotImplemented

__repr__()

Representation of the account.

Source code in wg_utilities/clients/monzo.py
397
398
399
def __repr__(self) -> str:
    """Return a short representation of the account, built from its `id`."""
    account_id = self.id

    return f"<Account {account_id}>"

from_json_response(value, monzo_client) classmethod

Create an account from a JSON response.

Source code in wg_utilities/clients/monzo.py
264
265
266
267
268
269
270
271
272
273
274
275
276
277
@classmethod
def from_json_response(
    cls,
    value: AccountJson,
    monzo_client: MonzoClient,
) -> Account:
    """Build an account from a JSON response and the client it came from.

    Args:
        value (AccountJson): the account JSON
        monzo_client (MonzoClient): the client stored on the new account

    Returns:
        Account: the validated account
    """

    # The client goes in first, then the JSON keys are layered on top of it.
    payload: dict[str, Any] = {"monzo_client": monzo_client}
    payload.update(value)

    return cls.model_validate(payload)

list_transactions(from_datetime=None, to_datetime=None, limit=100)

List transactions for the account.

Parameters:

Name Type Description Default
from_datetime datetime

the start of the time period to list transactions for. Defaults to None.

None
to_datetime datetime

the end of the time period to list transactions for. Defaults to None.

None
limit int

the maximum number of transactions to return. Defaults to 100.

100

Returns:

Type Description
list[Transaction]

the list of transactions

Source code in wg_utilities/clients/monzo.py
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
def list_transactions(
    self,
    from_datetime: datetime | None = None,
    to_datetime: datetime | None = None,
    limit: int = 100,
) -> list[Transaction]:
    """List transactions for the account.

    Timezone-aware datetimes are converted to UTC before being sent; naive
    datetimes are sent as they are (i.e. they are taken to be UTC already).

    Args:
        from_datetime (datetime, optional): the start of the time period to list
            transactions for. Defaults to 89 days before now.
        to_datetime (datetime, optional): the end of the time period to list
            transactions for. Defaults to now.
        limit (int, optional): the maximum number of transactions to return.
            Defaults to 100.

    Returns:
        list[Transaction]: the list of transactions
    """

    def _as_naive_utc(moment: datetime) -> datetime:
        """Drop microseconds and tzinfo, converting aware values to UTC first."""
        # Without this conversion, stripping the tzinfo of e.g. a +01:00
        # datetime and appending "Z" would shift the moment by an hour.
        if moment.utcoffset() is not None:
            moment = moment.astimezone(UTC)

        return moment.replace(microsecond=0, tzinfo=None)

    from_datetime = _as_naive_utc(
        from_datetime or (datetime.now(UTC) - timedelta(days=89)),
    )
    to_datetime = _as_naive_utc(to_datetime or datetime.now(UTC))

    return [
        Transaction(**item)
        for item in self.monzo_client.get_json_response(
            "/transactions",
            params={
                "account_id": self.id,
                "since": from_datetime.isoformat() + "Z",
                "before": to_datetime.isoformat() + "Z",
                "limit": limit,
            },
        )["transactions"]
    ]

update_balance_variables()

Update the balance-related instance attributes.

Latest values from the API are used. This is called automatically when a balance property is accessed; new values are only fetched if none have been fetched yet, or if the last update was at least balance_update_threshold minutes ago. Can also be called manually if required.

Source code in wg_utilities/clients/monzo.py
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
def update_balance_variables(self) -> None:
    """Refresh the cached balance figures from the API when they are stale.

    New values are fetched if none have been fetched yet, or if the last update
    was at least `balance_update_threshold` minutes ago; otherwise nothing
    happens. This runs automatically whenever a balance property is accessed,
    and can also be called manually.
    """

    # `hasattr` is checked first so the timestamp comparison is skipped
    # entirely until the first fetch has happened.
    still_fresh = hasattr(self, "_balance_variables") and (
        self.last_balance_update
        > datetime.now(UTC) - timedelta(minutes=self.balance_update_threshold)
    )
    if still_fresh:
        return

    LOGGER.debug("Balance variable update threshold crossed, getting new values")

    balance_json = self.monzo_client.get_json_response(
        f"/balance?account_id={self.id}",
    )
    self._balance_variables = BalanceVariables.model_validate(balance_json)

    self.last_balance_update = datetime.now(UTC)

validate_sort_code(sort_code) classmethod

Ensure that the sort code is a 6-digit integer.

Represented as a string so leading zeroes aren't lost.

Source code in wg_utilities/clients/monzo.py
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
@field_validator("sort_code", mode="before")
@classmethod
def validate_sort_code(cls, sort_code: str | int) -> str:
    """Ensure that the sort code is a 6-digit integer.

    Represented as a string so leading zeroes aren't lost. Values shorter than
    `SORT_CODE_LEN` (e.g. an integer whose leading zeroes were dropped) are
    left-padded with zeroes.

    Args:
        sort_code (str | int): the raw sort code

    Returns:
        str: the sort code, left-padded with zeroes to `SORT_CODE_LEN` digits

    Raises:
        ValueError: if the sort code contains anything other than digits
    """

    if isinstance(sort_code, int):
        sort_code = str(sort_code)

    # Check before padding, so an empty string is still rejected rather than
    # being turned into "000000".
    if not sort_code.isdigit():
        raise ValueError("Sort code must be a 6-digit integer")

    # Strings are immutable: the padded value has to be returned. Padding goes
    # on the left, as that is where an integer loses its zeroes.
    return sort_code.zfill(SORT_CODE_LEN)

AccountJson

Bases: TypedDict

JSON representation of a Monzo account.

Source code in wg_utilities/clients/monzo.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
@final
class AccountJson(TypedDict):
    """JSON representation of a Monzo account.

    This is the shape accepted by `Account.from_json_response`.
    """

    account_number: str | None
    # `balance`, `balance_including_flexible_savings`, `spend_today` and
    # `total_balance` are the validation aliases of the `initial_*` fields on
    # `Account`.
    balance: float
    balance_including_flexible_savings: float
    closed: bool | None
    country_code: str
    created: str
    currency: Literal["GBP"]
    description: str
    id: str
    owners: list[AccountOwner]
    payment_details: dict[str, dict[str, str]] | None
    sort_code: str | None
    spend_today: float
    total_balance: float
    type: Literal["uk_monzo_flex", "uk_retail", "uk_retail_joint"]

AccountOwner

Bases: TypedDict

The owner of a Monzo account.

Source code in wg_utilities/clients/monzo.py
207
208
209
210
211
212
class AccountOwner(TypedDict):
    """The owner of a Monzo account.

    Used as the element type of `AccountJson.owners`.
    """

    # All three values are plain strings
    preferred_first_name: str
    preferred_name: str
    user_id: str

BalanceVariables

Bases: BaseModelWithConfig

Variables for an account's balance summary.

Source code in wg_utilities/clients/monzo.py
194
195
196
197
198
199
200
201
202
203
204
class BalanceVariables(BaseModelWithConfig):
    """Variables for an account's balance summary.

    Instances are built by validating the JSON returned from the
    `/balance?account_id=...` endpoint (see `update_balance_variables`).
    """

    # NOTE(review): amounts are integers - presumably minor units (pence); confirm
    balance: int
    balance_including_flexible_savings: int
    # Only "GBP" is accepted
    currency: Literal["GBP"]
    local_currency: str
    # Accepts a number, None, or an empty string
    local_exchange_rate: int | float | None | Literal[""]
    local_spend: list[dict[str, int | str]]
    spend_today: int
    total_balance: int

MonzoClient

Bases: OAuthClient[MonzoGJR]

Custom client for interacting with Monzo's API.

Source code in wg_utilities/clients/monzo.py
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
class MonzoClient(OAuthClient[MonzoGJR]):
    """OAuth-backed client for talking to the Monzo API."""

    ACCESS_TOKEN_ENDPOINT = "https://api.monzo.com/oauth2/token"  # noqa: S105
    AUTH_LINK_BASE = "https://auth.monzo.com"
    BASE_URL = "https://api.monzo.com"

    DEFAULT_PARAMS: ClassVar[
        dict[StrBytIntFlt, StrBytIntFlt | Iterable[StrBytIntFlt] | None]
    ] = {}

    _current_account: Account

    def deposit_into_pot(
        self,
        pot: Pot,
        amount_pence: int,
        dedupe_id: str | None = None,
    ) -> None:
        """Transfer money from the user's current account into a pot.

        Args:
            pot (Pot): the pot which receives the money
            amount_pence (int): how much to move, in pence
            dedupe_id (str): unique string used to de-duplicate deposits. When
                omitted, one is built from the pot ID, the amount and the current
                UTC time

        Raises:
            HTTPError: if the API responds with an error status
        """

        if not dedupe_id:
            dedupe_parts = (pot.id, str(amount_pence), str(utcnow(DTU.SECOND)))
            dedupe_id = "|".join(dedupe_parts)

        deposit_url = f"{self.BASE_URL}/pots/{pot.id}/deposit"
        auth_headers = self.request_headers
        form_data = {
            "source_account_id": self.current_account.id,
            "amount": amount_pence,
            "dedupe_id": dedupe_id,
        }

        response = put(
            deposit_url,
            headers=auth_headers,
            data=form_data,
            timeout=10,
        )
        response.raise_for_status()

    def list_accounts(
        self,
        *,
        include_closed: bool = False,
        account_type: str | None = None,
    ) -> list[Account]:
        """Fetch the user's accounts.

        Args:
            include_closed (bool): when True, closed accounts are returned too
            account_type (str): restricts the lookup to one type of account; sent
                as a request parameter when provided

        Returns:
            list: one Account instance per matching account
        """

        query = {"account_type": account_type} if account_type else None
        payload = self.get_json_response("/accounts", params=query)

        accounts = []
        for raw_account in payload.get("accounts", []):
            # An account without a "closed" flag is treated as closed
            if include_closed or not raw_account.get("closed", True):
                accounts.append(Account.from_json_response(raw_account, self))

        return accounts

    def list_pots(self, *, include_deleted: bool = False) -> list[Pot]:
        """Fetch the pots attached to the user's current account.

        Args:
            include_deleted (bool): when True, deleted pots are returned too

        Returns:
            list: one Pot instance per matching pot
        """

        query = {"current_account_id": self.current_account.id}
        payload = self.get_json_response("/pots", params=query)

        pots = []
        for raw_pot in payload.get("pots", []):
            # A pot without a "deleted" flag is treated as deleted
            if include_deleted or not raw_pot.get("deleted", True):
                pots.append(Pot(**raw_pot))

        return pots

    def get_pot_by_id(self, pot_id: str) -> Pot | None:
        """Look up a pot (deleted or not) by its ID.

        Args:
            pot_id (str): the ID to search for

        Returns:
            Pot: the first pot with a matching ID, or None if there isn't one
        """
        candidates = self.list_pots(include_deleted=True)

        return next(
            (candidate for candidate in candidates if candidate.id == pot_id),
            None,
        )

    def get_pot_by_name(
        self,
        pot_name: str,
        *,
        exact_match: bool = False,
        include_deleted: bool = False,
    ) -> Pot | None:
        """Look up a pot by its name, ignoring case.

        Args:
            pot_name (str): the name to search for
            exact_match (bool): when False, both the search term and every pot
                name are cleansed before being compared
            include_deleted (bool): when True, deleted pots are searched too

        Returns:
            Pot: the first pot with a matching name, or None if there isn't one
        """
        needle = pot_name if exact_match else cleanse_string(pot_name)

        def _prepare(name: str) -> str:
            # Pot names only get cleansed for non-exact lookups
            return name if exact_match else cleanse_string(name)

        return next(
            (
                candidate
                for candidate in self.list_pots(include_deleted=include_deleted)
                if _prepare(candidate.name).lower() == needle.lower()
            ),
            None,
        )

    @property
    def current_account(self) -> Account:
        """The user's main ("uk_retail") account.

        The account is looked up on first access and then cached; only the first
        account returned is used, as one main account per user is assumed.

        Returns:
            Account: the user's main account, instantiated
        """
        if hasattr(self, "_current_account"):
            return self._current_account

        retail_accounts = self.list_accounts(account_type="uk_retail")
        self._current_account = retail_accounts[0]

        return self._current_account

    @property
    def request_headers(self) -> dict[str, str]:
        """Authorisation headers for API requests.

        Returns:
            dict: a single "Authorization" header carrying the bearer token
        """
        bearer_token = f"Bearer {self.access_token}"

        return {"Authorization": bearer_token}

current_account: Account property

Get the main account for the Monzo user.

We assume there'll only be one main account per user.

Returns:

Name Type Description
Account Account

the user's main account, instantiated

request_headers: dict[str, str] property

Header to be used in requests to the API.

Returns:

Name Type Description
dict dict[str, str]

auth headers for HTTP requests

deposit_into_pot(pot, amount_pence, dedupe_id=None)

Move money from the user's account into one of their pots.

Parameters:

Name Type Description Default
pot Pot

the target pot

required
amount_pence int

the amount of money to deposit, in pence

required
dedupe_id str

unique string used to de-duplicate deposits. Will be created if not provided

None
Source code in wg_utilities/clients/monzo.py
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
def deposit_into_pot(
    self,
    pot: Pot,
    amount_pence: int,
    dedupe_id: str | None = None,
) -> None:
    """Transfer money from the user's current account into a pot.

    Args:
        pot (Pot): the pot which receives the money
        amount_pence (int): how much to move, in pence
        dedupe_id (str): unique string used to de-duplicate deposits. When
            omitted, one is built from the pot ID, the amount and the current
            UTC time

    Raises:
        HTTPError: if the API responds with an error status
    """

    if not dedupe_id:
        dedupe_parts = (pot.id, str(amount_pence), str(utcnow(DTU.SECOND)))
        dedupe_id = "|".join(dedupe_parts)

    deposit_url = f"{self.BASE_URL}/pots/{pot.id}/deposit"
    auth_headers = self.request_headers
    form_data = {
        "source_account_id": self.current_account.id,
        "amount": amount_pence,
        "dedupe_id": dedupe_id,
    }

    response = put(
        deposit_url,
        headers=auth_headers,
        data=form_data,
        timeout=10,
    )
    response.raise_for_status()

get_pot_by_id(pot_id)

Get a pot from its ID.

Parameters:

Name Type Description Default
pot_id str

the ID of the pot to find

required

Returns:

Name Type Description
Pot Pot | None

the Pot instance

Source code in wg_utilities/clients/monzo.py
531
532
533
534
535
536
537
538
539
540
541
542
543
544
def get_pot_by_id(self, pot_id: str) -> Pot | None:
    """Look up a pot (deleted or not) by its ID.

    Args:
        pot_id (str): the ID to search for

    Returns:
        Pot: the first pot with a matching ID, or None if there isn't one
    """
    candidates = self.list_pots(include_deleted=True)

    return next(
        (candidate for candidate in candidates if candidate.id == pot_id),
        None,
    )

get_pot_by_name(pot_name, *, exact_match=False, include_deleted=False)

Get a pot from its name.

Parameters:

Name Type Description Default
pot_name str

the name of the pot to find

required
exact_match bool

if False, all pot names will be cleansed before evaluation

False
include_deleted bool

whether to include deleted pots in the response

False

Returns:

Name Type Description
Pot Pot | None

the Pot instance

Source code in wg_utilities/clients/monzo.py
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
def get_pot_by_name(
    self,
    pot_name: str,
    *,
    exact_match: bool = False,
    include_deleted: bool = False,
) -> Pot | None:
    """Look up a pot by its name, ignoring case.

    Args:
        pot_name (str): the name to search for
        exact_match (bool): when False, both the search term and every pot
            name are cleansed before being compared
        include_deleted (bool): when True, deleted pots are searched too

    Returns:
        Pot: the first pot with a matching name, or None if there isn't one
    """
    needle = pot_name if exact_match else cleanse_string(pot_name)

    def _prepare(name: str) -> str:
        # Pot names only get cleansed for non-exact lookups
        return name if exact_match else cleanse_string(name)

    return next(
        (
            candidate
            for candidate in self.list_pots(include_deleted=include_deleted)
            if _prepare(candidate.name).lower() == needle.lower()
        ),
        None,
    )

list_accounts(*, include_closed=False, account_type=None)

Get a list of the user's accounts.

Parameters:

Name Type Description Default
include_closed bool

whether to include closed accounts in the response

False
account_type str

the type of account(s) to find; submitted as param in request

None

Returns:

Name Type Description
list list[Account]

Account instances, containing all related info

Source code in wg_utilities/clients/monzo.py
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
def list_accounts(
    self,
    *,
    include_closed: bool = False,
    account_type: str | None = None,
) -> list[Account]:
    """Fetch the user's accounts.

    Args:
        include_closed (bool): when True, closed accounts are returned too
        account_type (str): restricts the lookup to one type of account; sent
            as a request parameter when provided

    Returns:
        list: one Account instance per matching account
    """

    query = {"account_type": account_type} if account_type else None
    payload = self.get_json_response("/accounts", params=query)

    accounts = []
    for raw_account in payload.get("accounts", []):
        # An account without a "closed" flag is treated as closed
        if include_closed or not raw_account.get("closed", True):
            accounts.append(Account.from_json_response(raw_account, self))

    return accounts

list_pots(*, include_deleted=False)

Get a list of the user's pots.

Parameters:

Name Type Description Default
include_deleted bool

whether to include deleted pots in the response

False

Returns:

Name Type Description
list list[Pot]

Pot instances, containing all related info

Source code in wg_utilities/clients/monzo.py
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
def list_pots(self, *, include_deleted: bool = False) -> list[Pot]:
    """Fetch the pots attached to the user's current account.

    Args:
        include_deleted (bool): when True, deleted pots are returned too

    Returns:
        list: one Pot instance per matching pot
    """

    query = {"current_account_id": self.current_account.id}
    payload = self.get_json_response("/pots", params=query)

    pots = []
    for raw_pot in payload.get("pots", []):
        # A pot without a "deleted" flag is treated as deleted
        if include_deleted or not raw_pot.get("deleted", True):
            pots.append(Pot(**raw_pot))

    return pots

MonzoGJR

Bases: TypedDict

The response type for MonzoClient.get_json_response.

Source code in wg_utilities/clients/monzo.py
430
431
432
433
434
435
class MonzoGJR(TypedDict):
    """The response type for `MonzoClient.get_json_response`.

    `MonzoClient` is declared as `OAuthClient[MonzoGJR]`, so this is the type
    parameter describing its JSON responses.
    """

    # Read by `MonzoClient.list_accounts` via `.get("accounts", [])`
    accounts: list[AccountJson]
    # Read by `MonzoClient.list_pots` via `.get("pots", [])`
    pots: list[PotJson]
    transactions: list[TransactionJson]

Pot

Bases: BaseModelWithConfig

Read-only class for Monzo pots.

Source code in wg_utilities/clients/monzo.py
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
class Pot(BaseModelWithConfig):
    """Read-only class for Monzo pots.

    Instances are created straight from the API's JSON in `MonzoClient.list_pots`
    (`Pot(**pot)`). Fields declared with `= None` may be omitted from the input;
    every other field has no default.
    """

    available_for_bills: bool
    balance: float
    charity_id: str | None = None
    cover_image_url: str
    created: datetime
    currency: str
    current_account_id: str
    deleted: bool
    goal_amount: float | None = None
    has_virtual_cards: bool
    # Compared against the search term in `MonzoClient.get_pot_by_id`
    id: str
    is_tax_pot: bool
    isa_wrapper: str
    # The only non-null lock type accepted is "until_date"
    lock_type: Literal["until_date"] | None = None
    locked: bool
    locked_until: datetime | None = None
    # Compared against the search term in `MonzoClient.get_pot_by_name`
    name: str
    product_id: str
    round_up: bool
    round_up_multiplier: float | None = None
    style: str
    type: str
    updated: datetime

PotJson

Bases: TypedDict

JSON representation of a pot.

Yes, this and the Pot class could've been replaced by Pydantic's create_model_from_typeddict, but it doesn't play nice with mypy :(

Source code in wg_utilities/clients/monzo.py
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
@final
class PotJson(TypedDict):
    """JSON representation of a pot.

    Yes, this and the `Pot` class could've been replaced by Pydantic's
    `create_model_from_typeddict`, but it doesn't play nice with mypy :(

    The keys mirror the fields of `Pot`; this is the element type of
    `MonzoGJR.pots`.
    """

    available_for_bills: bool
    balance: float
    charity_id: str | None
    cover_image_url: str
    created: datetime  # N.B. `str` actually, just parsed as `datetime`
    currency: str
    current_account_id: str
    # Read by `MonzoClient.list_pots` to filter out deleted pots
    deleted: bool
    goal_amount: float | None
    has_virtual_cards: bool
    id: str
    is_tax_pot: bool
    isa_wrapper: str
    lock_type: Literal["until_date"] | None
    locked: bool
    # NOTE(review): presumably also a `str` in the raw JSON, like `created` - confirm
    locked_until: datetime | None
    name: str
    product_id: str
    round_up: bool
    round_up_multiplier: float | None
    style: str
    type: str
    updated: datetime  # N.B. `str` actually, just parsed as `datetime`

Transaction

Bases: BaseModelWithConfig

Pydantic representation of a transaction.

Source code in wg_utilities/clients/monzo.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
class Transaction(BaseModelWithConfig):
    """Pydantic representation of a transaction.

    Fields declared with `= None` may be omitted from the input; every other
    field has no default.
    """

    account_id: str
    amount: int
    amount_is_pending: bool
    atm_fees_detailed: dict[str, int | str | None] | None = None
    # Typed so that None is the only accepted value
    attachments: None = None
    can_add_to_tab: bool
    can_be_excluded_from_breakdown: bool
    can_be_made_subscription: bool
    can_match_transactions_in_categorization: bool
    can_split_the_bill: bool
    # Mapping of category -> integer value
    categories: dict[
        TransactionCategory,
        int,
    ]
    category: TransactionCategory
    counterparty: dict[str, str]
    created: datetime
    currency: str
    decline_reason: str | None = None
    dedupe_id: str
    description: str
    fees: dict[str, Any] | None = None
    id: str
    include_in_spending: bool
    international: bool | None = None
    is_load: bool
    labels: list[str] | None = None
    local_amount: int
    local_currency: str
    # Nullable, but (unlike the other nullable fields) declared without a default
    merchant: str | None
    merchant_feedback_uri: str | None = None
    metadata: dict[str, str]
    notes: str
    originator: bool
    parent_account_id: str
    scheme: str
    # Kept as a string, unlike `created`/`updated` which are `datetime`
    settled: str
    tab: dict[str, object] | None = None
    updated: datetime
    user_id: str

TransactionJson

Bases: TypedDict

JSON representation of a transaction.

Same as above RE: Pydantic's create_model_from_typeddict.

Source code in wg_utilities/clients/monzo.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
@final
class TransactionJson(TypedDict):
    """JSON representation of a transaction.

    Same as above RE: Pydantic's `create_model_from_typeddict`.
    """

    account_id: str
    amount: int
    amount_is_pending: bool
    atm_fees_detailed: dict[str, int | str | None] | None
    attachments: None
    can_add_to_tab: bool
    can_be_excluded_from_breakdown: bool
    can_be_made_subscription: bool
    can_match_transactions_in_categorization: bool
    can_split_the_bill: bool
    categories: dict[
        TransactionCategory,
        int,
    ]
    category: TransactionCategory
    counterparty: dict[str, str]
    created: datetime
    currency: str
    decline_reason: str | None
    dedupe_id: str
    description: str
    fees: dict[str, Any] | None
    id: str
    include_in_spending: bool
    international: bool | None
    is_load: bool
    labels: list[str] | None
    local_amount: int
    local_currency: str
    merchant: str | None
    merchant_feedback_uri: str | None
    metadata: dict[str, str]
    notes: str
    originator: bool
    parent_account_id: str
    scheme: str
    settled: str
    tab: dict[str, object] | None
    updated: datetime
    user_id: str

oauth_client

Generic OAuth client to allow for reusable authentication flows/checks etc.

BaseModelWithConfig

Bases: BaseModel

Reusable BaseModel with Config to apply to all subclasses.

Source code in wg_utilities/clients/oauth_client.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
class BaseModelWithConfig(BaseModel):
    """Reusable `BaseModel` with Config to apply to all subclasses.

    As well as the shared `model_config`, the two serialisation methods are
    overridden so that `by_alias` and `exclude_unset` default to True.
    """

    # Shared Pydantic settings: allow non-Pydantic field types, ignore unknown
    # input keys, and validate values on attribute assignment
    model_config: ClassVar[ConfigDict] = ConfigDict(
        arbitrary_types_allowed=True,
        extra="ignore",
        validate_assignment=True,
    )

    def model_dump(  # noqa: PLR0913
        self,
        *,
        mode: Literal["json", "python"] | str = "python",
        include: IncEx | None = None,
        exclude: IncEx | None = None,
        context: dict[str, Any] | None = None,
        by_alias: bool = True,
        exclude_unset: bool = True,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        round_trip: bool = False,
        warnings: bool | Literal["none", "warn", "error"] = True,
        serialize_as_any: bool = False,
    ) -> dict[str, Any]:
        """Create a dictionary representation of the model.

        Overrides the standard `BaseModel.model_dump` method, so we can always
        return the dict with the same field names it came in with (`by_alias`),
        and leave out any fields which were not explicitly set, e.g. defaults
        added when parsing (`exclude_unset`).

        Every argument is forwarded unchanged to `BaseModel.model_dump`; only the
        two defaults listed below differ.

        Original documentation is here:
          - https://docs.pydantic.dev/latest/usage/serialization/#modelmodel_dump

        Overridden Parameters:
            by_alias: False -> True
            exclude_unset: False -> True

        Returns:
            dict: the serialised model
        """

        return super().model_dump(
            mode=mode,
            include=include,
            exclude=exclude,
            context=context,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            round_trip=round_trip,
            warnings=warnings,
            serialize_as_any=serialize_as_any,
        )

    def model_dump_json(  # noqa: PLR0913
        self,
        *,
        indent: int | None = None,
        include: IncEx | None = None,
        exclude: IncEx | None = None,
        context: dict[str, Any] | None = None,
        by_alias: bool = True,
        exclude_unset: bool = True,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        round_trip: bool = False,
        warnings: bool | Literal["none", "warn", "error"] = True,
        serialize_as_any: bool = False,
    ) -> str:
        """Create a JSON string representation of the model.

        Overrides the standard `BaseModel.model_dump_json` method, so we can
        always return the JSON with the same field names it came in with
        (`by_alias`), and leave out any fields which were not explicitly set,
        e.g. defaults added when parsing (`exclude_unset`).

        Every argument is forwarded unchanged to `BaseModel.model_dump_json`; only
        the two defaults listed below differ.

        Original documentation is here:
          - https://docs.pydantic.dev/latest/usage/serialization/#modelmodel_dump_json

        Overridden Parameters:
            by_alias: False -> True
            exclude_unset: False -> True

        Returns:
            str: the serialised model, as a JSON string
        """

        return super().model_dump_json(
            indent=indent,
            include=include,
            exclude=exclude,
            context=context,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            round_trip=round_trip,
            warnings=warnings,
            serialize_as_any=serialize_as_any,
        )

model_dump(*, mode='python', include=None, exclude=None, context=None, by_alias=True, exclude_unset=True, exclude_defaults=False, exclude_none=False, round_trip=False, warnings=True, serialize_as_any=False)

Create a dictionary representation of the model.

Overrides the standard BaseModel.model_dump method, so that the returned dict always uses the same field names the data came in with (by_alias=True) and leaves out any fields that were not explicitly set, such as defaults added when parsing (exclude_unset=True).

Original documentation is here:
  • https://docs.pydantic.dev/latest/usage/serialization/#modelmodel_dump

Overridden parameters:
  • by_alias: False -> True
  • exclude_unset: False -> True

Source code in wg_utilities/clients/oauth_client.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def model_dump(  # noqa: PLR0913
    self,
    *,
    mode: Literal["json", "python"] | str = "python",
    include: IncEx | None = None,
    exclude: IncEx | None = None,
    context: dict[str, Any] | None = None,
    by_alias: bool = True,
    exclude_unset: bool = True,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
    round_trip: bool = False,
    warnings: bool | Literal["none", "warn", "error"] = True,
    serialize_as_any: bool = False,
) -> dict[str, Any]:
    """Create a dictionary representation of the model.

    Overrides the standard `BaseModel.model_dump` method, so we can always
    return the dict with the same field names it came in with (`by_alias`),
    and leave out any fields which were not explicitly set, e.g. defaults
    added when parsing (`exclude_unset`).

    Every argument is forwarded unchanged to `BaseModel.model_dump`; only the
    two defaults listed below differ.

    Original documentation is here:
      - https://docs.pydantic.dev/latest/usage/serialization/#modelmodel_dump

    Overridden Parameters:
        by_alias: False -> True
        exclude_unset: False -> True

    Returns:
        dict: the serialised model
    """

    return super().model_dump(
        mode=mode,
        include=include,
        exclude=exclude,
        context=context,
        by_alias=by_alias,
        exclude_unset=exclude_unset,
        exclude_defaults=exclude_defaults,
        exclude_none=exclude_none,
        round_trip=round_trip,
        warnings=warnings,
        serialize_as_any=serialize_as_any,
    )

model_dump_json(*, indent=None, include=None, exclude=None, context=None, by_alias=True, exclude_unset=True, exclude_defaults=False, exclude_none=False, round_trip=False, warnings=True, serialize_as_any=False)

Create a JSON string representation of the model.

Overrides the standard BaseModel.model_dump_json method, so that the returned JSON always uses the same field names the data came in with (by_alias=True) and leaves out any fields that were not explicitly set, such as defaults added when parsing (exclude_unset=True).

Original documentation is here:
  • https://docs.pydantic.dev/latest/usage/serialization/#modelmodel_dump_json

Overridden parameters:
  • by_alias: False -> True
  • exclude_unset: False -> True

Source code in wg_utilities/clients/oauth_client.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def model_dump_json(  # noqa: PLR0913
    self,
    *,
    indent: int | None = None,
    include: IncEx | None = None,
    exclude: IncEx | None = None,
    context: dict[str, Any] | None = None,
    by_alias: bool = True,
    exclude_unset: bool = True,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
    round_trip: bool = False,
    warnings: bool | Literal["none", "warn", "error"] = True,
    serialize_as_any: bool = False,
) -> str:
    """Create a JSON string representation of the model.

    Overrides the standard `BaseModel.model_dump_json` method, so we can
    always return the JSON with the same field names it came in with
    (`by_alias`), and leave out any fields which were not explicitly set,
    e.g. defaults added when parsing (`exclude_unset`).

    Every argument is forwarded unchanged to `BaseModel.model_dump_json`; only
    the two defaults listed below differ.

    Original documentation is here:
      - https://docs.pydantic.dev/latest/usage/serialization/#modelmodel_dump_json

    Overridden Parameters:
        by_alias: False -> True
        exclude_unset: False -> True

    Returns:
        str: the serialised model, as a JSON string
    """

    return super().model_dump_json(
        indent=indent,
        include=include,
        exclude=exclude,
        context=context,
        by_alias=by_alias,
        exclude_unset=exclude_unset,
        exclude_defaults=exclude_defaults,
        exclude_none=exclude_none,
        round_trip=round_trip,
        warnings=warnings,
        serialize_as_any=serialize_as_any,
    )

OAuthClient

Bases: JsonApiClient[GetJsonResponse]

Custom client for interacting with OAuth APIs.

Includes all necessary/basic authentication functionality

Source code in wg_utilities/clients/oauth_client.py
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
class OAuthClient(JsonApiClient[GetJsonResponse]):
    """Custom client for interacting with OAuth APIs.

    Includes all necessary/basic authentication functionality: the first time
    (authorization code) login, refreshing of the access token, and caching of
    the resulting credentials in a local JSON file.
    """

    # Declared without a value here: each must be provided either by a subclass or
    # via the matching `__init__` argument, otherwise `__init__` fails on lookup.
    ACCESS_TOKEN_ENDPOINT: str
    AUTH_LINK_BASE: str

    # Number of seconds before the recorded expiry at which the access token is
    # already treated as expired (see `access_token_has_expired`).
    ACCESS_TOKEN_EXPIRY_THRESHOLD = 150

    # Fallback credentials cache directory, read from the environment at import time.
    DEFAULT_CACHE_DIR = getenv("WG_UTILITIES_CREDS_CACHE_DIR")

    DEFAULT_SCOPES: ClassVar[list[str]] = []

    # When True, the auth link is sent to a callback instead of opening a browser.
    HEADLESS_MODE = getenv("WG_UTILITIES_HEADLESS_MODE", "0") == "1"

    # Both are set lazily; their presence is tested with `hasattr` throughout.
    _credentials: OAuthCredentials
    _temp_auth_server: TempAuthServer

    def __init__(  # noqa: PLR0913
        self,
        *,
        client_id: str | None = None,
        client_secret: str | None = None,
        log_requests: bool = False,
        creds_cache_path: Path | None = None,
        creds_cache_dir: Path | None = None,
        scopes: list[str] | None = None,
        oauth_login_redirect_host: str = "localhost",
        oauth_redirect_uri_override: str | None = None,
        headless_auth_link_callback: Callable[[str], None] | None = None,
        use_existing_credentials_only: bool = False,
        access_token_endpoint: str | None = None,
        auth_link_base: str | None = None,
        base_url: str | None = None,
        validate_request_success: bool = True,
    ) -> None:
        """Initialise the client.

        Args:
            client_id (str, optional): the client ID for the API. Defaults to None.
            client_secret (str, optional): the client secret for the API. Defaults to
                None.
            log_requests (bool, optional): whether to log requests. Defaults to False.
            creds_cache_path (Path, optional): the path to the credentials cache file.
                Defaults to None. Overrides `creds_cache_dir`.
            creds_cache_dir (Path, optional): the path to the credentials cache directory.
                Overrides environment variable `WG_UTILITIES_CREDS_CACHE_DIR`. Defaults to
                None.
            scopes (list[str], optional): the scopes to request when authenticating.
                Defaults to None, in which case `DEFAULT_SCOPES` is used.
            oauth_login_redirect_host (str, optional): the host to redirect to after
                authenticating. Defaults to "localhost".
            oauth_redirect_uri_override (str, optional): override the redirect URI
                specified in the OAuth flow. Defaults to None.
            headless_auth_link_callback (Callable[[str], None], optional): callback to
                send the auth link to when running in headless mode. Defaults to None.
            use_existing_credentials_only (bool, optional): whether to only use existing
                credentials, and not attempt to authenticate. Defaults to False.
            access_token_endpoint (str, optional): the endpoint to use to get an access
                token. Defaults to None, in which case `ACCESS_TOKEN_ENDPOINT` is used.
            auth_link_base (str, optional): the base URL to use to generate the auth
                link. Defaults to None, in which case `AUTH_LINK_BASE` is used.
            base_url (str, optional): the base URL to use for API requests. Defaults to
                None.
            validate_request_success (bool, optional): whether to validate that the
                request was successful. Defaults to True.
        """
        super().__init__(
            log_requests=log_requests,
            base_url=base_url,
            validate_request_success=validate_request_success,
        )

        self._client_id = client_id
        self._client_secret = client_secret
        self.access_token_endpoint = access_token_endpoint or self.ACCESS_TOKEN_ENDPOINT
        self.auth_link_base = auth_link_base or self.AUTH_LINK_BASE
        self.oauth_login_redirect_host = oauth_login_redirect_host
        self.oauth_redirect_uri_override = oauth_redirect_uri_override
        self.headless_auth_link_callback = headless_auth_link_callback
        self.use_existing_credentials_only = use_existing_credentials_only

        # Precedence: explicit file path > explicit directory > environment default
        if creds_cache_path:
            self._creds_cache_path: Path | None = creds_cache_path
            self._creds_cache_dir: Path | None = None
        elif creds_cache_dir:
            self._creds_cache_path = None
            self._creds_cache_dir = creds_cache_dir
        else:
            self._creds_cache_path = None
            if self.DEFAULT_CACHE_DIR:
                self._creds_cache_dir = Path(self.DEFAULT_CACHE_DIR)
            else:
                self._creds_cache_dir = None

        self.scopes = scopes or self.DEFAULT_SCOPES

        # Only an explicit file path is loaded eagerly; otherwise the credentials
        # are loaded on first use
        if self._creds_cache_path:
            self._load_local_credentials()

    def _load_local_credentials(self) -> bool:
        """Load credentials from the local cache.

        Returns:
            bool: True if the credentials were loaded successfully, False if the
                cache file does not exist
        """
        try:
            self._credentials = OAuthCredentials.model_validate_json(
                self.creds_cache_path.read_text(),
            )
        except FileNotFoundError:
            return False

        return True

    def delete_creds_file(self) -> None:
        """Delete the local creds file (no error if it doesn't exist)."""
        self.creds_cache_path.unlink(missing_ok=True)

    def refresh_access_token(self) -> None:
        """Refresh the access token and write the updated credentials to the cache.

        If no credentials are held in memory or in the local cache, the first time
        login is run instead.
        """

        if not hasattr(self, "_credentials") and not self._load_local_credentials():
            # If we don't have any credentials, we can't refresh the access token -
            # perform first time login and leave it at that
            self.run_first_time_login()
            return

        LOGGER.info("Refreshing access token")

        payload = {
            "grant_type": "refresh_token",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "refresh_token": self.credentials.refresh_token,
        }

        new_creds = self.post_json_response(
            self.access_token_endpoint,
            data=payload,
            header_overrides={},
        )

        self.credentials.update_access_token(
            new_token=new_creds["access_token"],
            expires_in=new_creds["expires_in"],
            # Monzo
            refresh_token=new_creds.get("refresh_token"),
        )

        self.creds_cache_path.write_text(
            self.credentials.model_dump_json(exclude_none=True),
        )

    def run_first_time_login(self) -> None:
        """Run the first time login process.

        This is a blocking call which will not return until the user has
        authenticated with the OAuth provider.

        Raises:
            RuntimeError: if `use_existing_credentials_only` is set to True
            ValueError: if the state token returned by the OAuth provider does not
                match
        """
        # Function-scope import so the method does not rely on a module-level one
        from secrets import choice as secure_choice  # noqa: PLC0415

        if self.use_existing_credentials_only:
            raise RuntimeError(
                "No existing credentials found, and `use_existing_credentials_only` "
                "is set to True",
            )

        LOGGER.info("Performing first time login")

        # The state token guards the redirect against request forgery, so it is
        # drawn from a cryptographically secure source rather than `random`
        state_token = "".join(secure_choice(ascii_letters) for _ in range(32))

        self.temp_auth_server.start_server()

        if self.oauth_redirect_uri_override:
            redirect_uri = self.oauth_redirect_uri_override
        else:
            redirect_uri = f"http://{self.oauth_login_redirect_host}:{self.temp_auth_server.port}/get_auth_code"

        auth_link_params = {
            "client_id": self._client_id,
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "state": state_token,
            "access_type": "offline",
            "prompt": "consent",
        }

        if self.scopes:
            auth_link_params["scope"] = " ".join(self.scopes)

        auth_link = self.auth_link_base + "?" + urlencode(auth_link_params)

        if self.HEADLESS_MODE:
            if self.headless_auth_link_callback is None:
                LOGGER.warning(
                    "Headless mode is enabled, but no headless auth link callback "
                    "has been set. The auth link will not be opened.",
                )
                LOGGER.debug("Auth link: %s", auth_link)
            else:
                LOGGER.info("Sending auth link to callback")
                self.headless_auth_link_callback(auth_link)
        else:
            open_browser(auth_link)

        request_args = self.temp_auth_server.wait_for_request(
            "/get_auth_code",
            kill_on_request=True,
        )

        if state_token != request_args.get("state"):
            raise ValueError(
                "State token received in request doesn't match expected value: "
                f"`{request_args.get('state')}` != `{state_token}`",
            )

        # These clients get the token request sent as form data (`data=`) with a
        # form content type; every other client gets a JSON body (`json=`)
        uses_form_encoding = type(self).__name__ in ("MonzoClient", "SpotifyClient")
        payload_key = "data" if uses_form_encoding else "json"

        res = self._post(
            self.access_token_endpoint,
            **{  # type: ignore[arg-type]
                payload_key: {
                    "code": request_args["code"],
                    "grant_type": "authorization_code",
                    "client_id": self._client_id,
                    "client_secret": self._client_secret,
                    "redirect_uri": redirect_uri,
                },
            },
            # Stops recursive call to `self.request_headers`
            header_overrides=(
                {"Content-Type": "application/x-www-form-urlencoded"}
                if uses_form_encoding
                else {}
            ),
        )

        credentials = res.json()

        # Store the client details alongside the tokens so that they can be
        # recovered from the cache file later
        if self._client_id:
            credentials["client_id"] = self._client_id

        if self._client_secret:
            credentials["client_secret"] = self._client_secret

        # Assigning via the property setter also writes the cache file
        self.credentials = OAuthCredentials.parse_first_time_login(credentials)

    @property
    def _creds_rel_file_path(self) -> Path | None:
        """Get the credentials cache filepath relative to the cache directory.

        Overridable in subclasses.

        Returns:
            Path | None: `<class name>/<client ID>.json`, or None if no client ID was
                passed to the constructor and no credentials have been loaded
        """

        try:
            client_id = self._client_id or self._credentials.client_id
        except AttributeError:
            # `_credentials` has not been set yet
            return None

        return Path(type(self).__name__, f"{client_id}.json")

    @property
    def access_token(self) -> str | None:
        """Access token, refreshed first if it has expired (or is about to).

        Returns:
            str: the access token for the API
        """
        if self.access_token_has_expired:
            self.refresh_access_token()

        return self.credentials.access_token

    @property
    def access_token_has_expired(self) -> bool:
        """Check whether the access token has expired, or is about to.

        The token counts as expired once it is within
        `ACCESS_TOKEN_EXPIRY_THRESHOLD` seconds of its recorded expiry time, and
        also when no credentials are available at all.

        Returns:
            bool: has the access token expired?
        """
        if not hasattr(self, "_credentials") and not self._load_local_credentials():
            return True

        return (
            self.credentials.expiry_epoch
            < int(time()) + self.ACCESS_TOKEN_EXPIRY_THRESHOLD
        )

    @property
    def client_id(self) -> str:
        """Client ID for the API.

        The value passed to the constructor takes precedence over the one stored
        in the credentials.

        Returns:
            str: the current client ID
        """

        return self._client_id or self.credentials.client_id

    @property
    def client_secret(self) -> str | None:
        """Client secret.

        The value passed to the constructor takes precedence over the one stored
        in the credentials.

        Returns:
            str: the current client secret
        """

        return self._client_secret or self.credentials.client_secret

    @property
    def credentials(self) -> OAuthCredentials:
        """Get the credentials, running the first time login if there are none.

        Returns:
            OAuthCredentials: the credentials for the API

        Raises:
            RuntimeError: if there are no existing credentials and
                `use_existing_credentials_only` is set to True
            ValueError: if the state token returned from the request doesn't match the
                expected value
        """
        if not hasattr(self, "_credentials") and not self._load_local_credentials():
            self.run_first_time_login()

        return self._credentials

    @credentials.setter
    def credentials(self, value: OAuthCredentials) -> None:
        """Set the client's credentials, and write to the local cache file."""

        self._credentials = value

        self.creds_cache_path.write_text(
            dumps(self._credentials.model_dump(exclude_none=True)),
        )

    @property
    def creds_cache_path(self) -> Path:
        """Path to the credentials cache file.

        An explicitly configured file path is returned as-is. Otherwise the path is
        built from the cache directory (or a default location under the user data
        directory) and `_creds_rel_file_path`.

        Returns:
            Path: the path to the credentials cache file

        Raises:
            ValueError: if the path to the credentials cache file is not set, and can't
                be generated due to a lack of client ID
        """
        if self._creds_cache_path:
            return self._creds_cache_path

        if not self._creds_rel_file_path:
            raise ValueError(
                "Unable to get client ID to generate path for credentials cache file.",
            )

        # NOTE(review): `force_mkdir(..., path_is_file=True)` presumably creates the
        # parent directories and returns the path - confirm against the helper
        return force_mkdir(
            (self._creds_cache_dir or user_data_dir() / "oauth_credentials")
            / self._creds_rel_file_path,
            path_is_file=True,
        )

    @property
    def request_headers(self) -> dict[str, str]:
        """Header to be used in requests to the API.

        Reading this property reads `access_token`, which may trigger a token
        refresh or the first time login.

        Returns:
            dict: auth headers for HTTP requests
        """
        return {
            "Authorization": f"Bearer {self.access_token}",
            "Content-Type": "application/json",
        }

    @property
    def refresh_token(self) -> str:
        """Refresh token.

        Returns:
            str: the API refresh token
        """
        return self.credentials.refresh_token

    @property
    def temp_auth_server(self) -> TempAuthServer:
        """Get the temporary HTTP server for the auth flow, creating it on first use.

        Returns:
            TempAuthServer: the temporary server
        """
        if not hasattr(self, "_temp_auth_server"):
            self._temp_auth_server = TempAuthServer(__name__, auto_run=False)

        return self._temp_auth_server

access_token: str | None property

Access token.

Returns:

Name Type Description
str str | None

the access token for the API

access_token_has_expired: bool property

Check whether the access token has expired, or will expire within ACCESS_TOKEN_EXPIRY_THRESHOLD seconds.

Returns:

Name Type Description
bool bool

has the access token expired?

client_id: str property

Client ID for the API.

Returns:

Name Type Description
str str

the current client ID

client_secret: str | None property

Client secret.

Returns:

Name Type Description
str str | None

the current client secret

credentials: OAuthCredentials property writable

Get creds as necessary (including first time setup) and authenticates them.

Returns:

Name Type Description
OAuthCredentials OAuthCredentials

the credentials for the API

Raises:

Type Description
ValueError

if the state token returned from the request doesn't match the expected value

creds_cache_path: Path property

Path to the credentials cache file.

Returns:

Name Type Description
Path Path

the path to the credentials cache file

Raises:

Type Description
ValueError

if the path to the credentials cache file is not set, and can't be generated due to a lack of client ID

refresh_token: str property

Refresh token.

Returns:

Name Type Description
str str

the API refresh token

request_headers: dict[str, str] property

Header to be used in requests to the API.

Returns:

Name Type Description
dict dict[str, str]

auth headers for HTTP requests

temp_auth_server: TempAuthServer property

Create a temporary HTTP server for the auth flow.

Returns:

Name Type Description
TempAuthServer TempAuthServer

the temporary server

__init__(*, client_id=None, client_secret=None, log_requests=False, creds_cache_path=None, creds_cache_dir=None, scopes=None, oauth_login_redirect_host='localhost', oauth_redirect_uri_override=None, headless_auth_link_callback=None, use_existing_credentials_only=False, access_token_endpoint=None, auth_link_base=None, base_url=None, validate_request_success=True)

Initialise the client.

Parameters:

Name Type Description Default
client_id str

the client ID for the API. Defaults to None.

None
client_secret str

the client secret for the API. Defaults to None.

None
log_requests bool

whether to log requests. Defaults to False.

False
creds_cache_path Path

the path to the credentials cache file. Defaults to None. Overrides creds_cache_dir.

None
creds_cache_dir Path

the path to the credentials cache directory. Overrides environment variable WG_UTILITIES_CREDS_CACHE_DIR. Defaults to None.

None
scopes list[str]

the scopes to request when authenticating. Defaults to None.

None
oauth_login_redirect_host str

the host to redirect to after authenticating. Defaults to "localhost".

'localhost'
oauth_redirect_uri_override str

override the redirect URI specified in the OAuth flow. Defaults to None.

None
headless_auth_link_callback Callable[[str], None]

callback to send the auth link to when running in headless mode. Defaults to None.

None
use_existing_credentials_only bool

whether to only use existing credentials, and not attempt to authenticate. Defaults to False.

False
access_token_endpoint str

the endpoint to use to get an access token. Defaults to None.

None
auth_link_base str

the base URL to use to generate the auth link. Defaults to None.

None
base_url str

the base URL to use for API requests. Defaults to None.

None
validate_request_success bool

whether to validate that the request was successful. Defaults to True.

True
Source code in wg_utilities/clients/oauth_client.py
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
def __init__(  # noqa: PLR0913
    self,
    *,
    client_id: str | None = None,
    client_secret: str | None = None,
    log_requests: bool = False,
    creds_cache_path: Path | None = None,
    creds_cache_dir: Path | None = None,
    scopes: list[str] | None = None,
    oauth_login_redirect_host: str = "localhost",
    oauth_redirect_uri_override: str | None = None,
    headless_auth_link_callback: Callable[[str], None] | None = None,
    use_existing_credentials_only: bool = False,
    access_token_endpoint: str | None = None,
    auth_link_base: str | None = None,
    base_url: str | None = None,
    validate_request_success: bool = True,
):
    """Set up the client's OAuth configuration and credentials cache location.

    Args:
        client_id (str, optional): the client ID for the API. Defaults to None.
        client_secret (str, optional): the client secret for the API. Defaults to
            None.
        log_requests (bool, optional): whether to log requests. Defaults to False.
        creds_cache_path (Path, optional): the path to the credentials cache file.
            Takes precedence over `creds_cache_dir`. Defaults to None.
        creds_cache_dir (Path, optional): the path to the credentials cache
            directory. Takes precedence over the environment variable
            `WG_UTILITIES_CREDS_CACHE_DIR`. Defaults to None.
        scopes (list[str], optional): the scopes to request when authenticating;
            `DEFAULT_SCOPES` is used when not given. Defaults to None.
        oauth_login_redirect_host (str, optional): the host to redirect to after
            authenticating. Defaults to "localhost".
        oauth_redirect_uri_override (str, optional): override the redirect URI
            specified in the OAuth flow. Defaults to None.
        headless_auth_link_callback (Callable[[str], None], optional): callback to
            send the auth link to when running in headless mode. Defaults to None.
        use_existing_credentials_only (bool, optional): whether to only use
            existing credentials, and not attempt to authenticate. Defaults to
            False.
        access_token_endpoint (str, optional): the endpoint to use to get an access
            token; `ACCESS_TOKEN_ENDPOINT` is used when not given. Defaults to None.
        auth_link_base (str, optional): the base URL to use to generate the auth
            link; `AUTH_LINK_BASE` is used when not given. Defaults to None.
        base_url (str, optional): the base URL to use for API requests. Defaults to
            None.
        validate_request_success (bool, optional): whether to validate that the
            request was successful. Defaults to True.
    """
    super().__init__(
        base_url=base_url,
        log_requests=log_requests,
        validate_request_success=validate_request_success,
    )

    self._client_id = client_id
    self._client_secret = client_secret
    self.access_token_endpoint = access_token_endpoint or self.ACCESS_TOKEN_ENDPOINT
    self.auth_link_base = auth_link_base or self.AUTH_LINK_BASE
    self.oauth_login_redirect_host = oauth_login_redirect_host
    self.oauth_redirect_uri_override = oauth_redirect_uri_override
    self.headless_auth_link_callback = headless_auth_link_callback
    self.use_existing_credentials_only = use_existing_credentials_only

    # Resolve the cache location; at most one of the two ends up set.
    # Precedence: explicit file path > explicit directory > environment default
    cache_file: Path | None = None
    cache_dir: Path | None = None

    if creds_cache_path:
        cache_file = creds_cache_path
    elif creds_cache_dir:
        cache_dir = creds_cache_dir
    elif self.DEFAULT_CACHE_DIR:
        cache_dir = Path(self.DEFAULT_CACHE_DIR)

    self._creds_cache_path: Path | None = cache_file
    self._creds_cache_dir: Path | None = cache_dir

    self.scopes = scopes or self.DEFAULT_SCOPES

    # Credentials are only loaded eagerly when an explicit file path was given
    if self._creds_cache_path:
        self._load_local_credentials()

delete_creds_file()

Delete the local creds file.

Source code in wg_utilities/clients/oauth_client.py
357
358
359
def delete_creds_file(self) -> None:
    """Remove the local credentials cache file; a missing file is not an error."""
    cache_file = self.creds_cache_path
    cache_file.unlink(missing_ok=True)

refresh_access_token()

Refresh access token.

Source code in wg_utilities/clients/oauth_client.py
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
def refresh_access_token(self) -> None:
    """Exchange the refresh token for a new access token and cache the result.

    When no credentials are held in memory and none can be loaded from the
    local cache, the first time login is run instead and nothing is refreshed.
    """

    have_credentials = hasattr(self, "_credentials") or self._load_local_credentials()

    if not have_credentials:
        # Nothing to refresh with: fall back to the first time login, which
        # obtains a fresh set of credentials by itself
        self.run_first_time_login()
        return

    LOGGER.info("Refreshing access token")

    refresh_request = {
        "grant_type": "refresh_token",
        "client_id": self.client_id,
        "client_secret": self.client_secret,
        "refresh_token": self.credentials.refresh_token,
    }

    # NOTE(review): the empty overrides presumably stop `request_headers` (which
    # reads `access_token`) from being evaluated mid-refresh - confirm
    token_response = self.post_json_response(
        self.access_token_endpoint,
        data=refresh_request,
        header_overrides={},
    )

    current_creds = self.credentials
    current_creds.update_access_token(
        new_token=token_response["access_token"],
        expires_in=token_response["expires_in"],
        # Monzo
        refresh_token=token_response.get("refresh_token"),
    )

    # Persist the refreshed credentials to the cache file
    serialized = self.credentials.model_dump_json(exclude_none=True)
    self.creds_cache_path.write_text(serialized)

run_first_time_login()

Run the first time login process.

This is a blocking call which will not return until the user has authenticated with the OAuth provider.

Raises:

Type Description
RuntimeError

if use_existing_credentials_only is set to True

ValueError

if the state token returned by the OAuth provider does not match

Source code in wg_utilities/clients/oauth_client.py
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
def run_first_time_login(self) -> None:
    """Run the first time login process.

    This is a blocking call which will not return until the user has
    authenticated with the OAuth provider. On success the new credentials are
    assigned to `self.credentials`.

    Raises:
        RuntimeError: if `use_existing_credentials_only` is set to True
        ValueError: if the state token returned by the OAuth provider does not
            match
    """
    # Function-scope import so the method does not rely on a module-level one
    from secrets import choice as secure_choice  # noqa: PLC0415

    if self.use_existing_credentials_only:
        raise RuntimeError(
            "No existing credentials found, and `use_existing_credentials_only` "
            "is set to True",
        )

    LOGGER.info("Performing first time login")

    # The state token guards the redirect against request forgery, so it is
    # drawn from a cryptographically secure source rather than `random`
    state_token = "".join(secure_choice(ascii_letters) for _ in range(32))

    self.temp_auth_server.start_server()

    if self.oauth_redirect_uri_override:
        redirect_uri = self.oauth_redirect_uri_override
    else:
        redirect_uri = f"http://{self.oauth_login_redirect_host}:{self.temp_auth_server.port}/get_auth_code"

    auth_link_params = {
        "client_id": self._client_id,
        "redirect_uri": redirect_uri,
        "response_type": "code",
        "state": state_token,
        "access_type": "offline",
        "prompt": "consent",
    }

    if self.scopes:
        auth_link_params["scope"] = " ".join(self.scopes)

    auth_link = self.auth_link_base + "?" + urlencode(auth_link_params)

    if self.HEADLESS_MODE:
        if self.headless_auth_link_callback is None:
            LOGGER.warning(
                "Headless mode is enabled, but no headless auth link callback "
                "has been set. The auth link will not be opened.",
            )
            LOGGER.debug("Auth link: %s", auth_link)
        else:
            LOGGER.info("Sending auth link to callback")
            self.headless_auth_link_callback(auth_link)
    else:
        open_browser(auth_link)

    request_args = self.temp_auth_server.wait_for_request(
        "/get_auth_code",
        kill_on_request=True,
    )

    if state_token != request_args.get("state"):
        raise ValueError(
            "State token received in request doesn't match expected value: "
            f"`{request_args.get('state')}` != `{state_token}`",
        )

    # These clients get the token request sent as form data (`data=`) with a
    # form content type; every other client gets a JSON body (`json=`)
    uses_form_encoding = type(self).__name__ in ("MonzoClient", "SpotifyClient")
    payload_key = "data" if uses_form_encoding else "json"

    res = self._post(
        self.access_token_endpoint,
        **{  # type: ignore[arg-type]
            payload_key: {
                "code": request_args["code"],
                "grant_type": "authorization_code",
                "client_id": self._client_id,
                "client_secret": self._client_secret,
                "redirect_uri": redirect_uri,
            },
        },
        # Stops recursive call to `self.request_headers`
        header_overrides=(
            {"Content-Type": "application/x-www-form-urlencoded"}
            if uses_form_encoding
            else {}
        ),
    )

    credentials = res.json()

    # Store the client details alongside the tokens so that they can be
    # recovered from the cached credentials later
    if self._client_id:
        credentials["client_id"] = self._client_id

    if self._client_secret:
        credentials["client_secret"] = self._client_secret

    self.credentials = OAuthCredentials.parse_first_time_login(credentials)

OAuthCredentials

Bases: BaseModelWithConfig

Typing info for OAuth credentials.

Source code in wg_utilities/clients/oauth_client.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
class OAuthCredentials(BaseModelWithConfig):
    """Typing info for OAuth credentials.

    Holds the tokens and client details for one OAuth client, along with the
    epoch time at which the access token expires.
    """

    access_token: str
    client_id: str
    expiry_epoch: float
    refresh_token: str
    scope: str
    token_type: Literal["Bearer"]
    client_secret: str

    # Monzo
    user_id: str | None = None

    # Google
    token: str | None = None
    token_uri: str | None = None
    scopes: list[str] | None = None

    @classmethod
    def parse_first_time_login(cls, value: dict[str, Any]) -> OAuthCredentials:
        """Parse the response from a first time login into a credentials object.

        The expiry time is taken from the `exp` claim when the access token can be
        decoded as a JWT, and otherwise calculated from `expires_in`. The input
        dictionary is not modified.

        The following fields are returned per API:
        +---------------+--------+-------+---------+-----------+
        |               | Google | Monzo | Spotify | TrueLayer |
        +===============+========+=======+=========+===========+
        | access_token  |    X   |   X   |    X    |     X     |
        | client_id     |    X   |   X   |    X    |     X     |
        | expiry_epoch  |    X   |   X   |    X    |     X     |
        | refresh_token |    X   |   X   |    X    |     X     |
        | scope         |    X   |   X   |    X    |     X     |
        | token_type    |    X   |   X   |    X    |     X     |
        | client_secret |    X   |       |    X    |           |
        | user_id       |        |   X   |         |           |
        | token         |    X   |       |         |           |
        | token_uri     |    X   |       |         |           |
        | scopes        |    X   |       |         |           |
        +---------------+--------+-------+---------+-----------+

        Args:
            value: the response from the API

        Returns:
            OAuthCredentials: an OAuthCredentials instance

        Raises:
            KeyError: if the expiry can't be read from the token and `expires_in`
                is missing from the response
            ValueError: if `expiry` and the expiry time calculated from
                `expires_in` differ by more than 60 seconds
        """

        # Work on a shallow copy so that the caller's dictionary is left untouched
        payload = dict(value)

        # Calculate the expiry time of the access token
        try:
            # Try to decode it if it's a valid JWT (with expiry)
            expiry_epoch = decode(
                payload["access_token"],
                options={"verify_signature": False},
            )["exp"]
            payload.pop("expires_in", None)
        except (DecodeError, KeyError):
            # If that's not possible, calculate it from the expires_in value
            expires_in = payload.pop("expires_in")

            # Subtract 2.5 seconds to account for latency
            expiry_epoch = time() + expires_in - 2.5

            # Verify it against the expiry time string
            if expiry_time_str := payload.get("expiry"):
                expiry_time = datetime.fromisoformat(expiry_time_str)
                if abs(expiry_epoch - expiry_time.timestamp()) > 60:  # noqa: PLR2004
                    raise ValueError(
                        "`expiry` and `expires_in` are not consistent with each other:"
                        f" expiry: {expiry_time_str}, expires_in: {expiry_epoch}",
                    ) from None

        payload["expiry_epoch"] = expiry_epoch

        return cls(**payload)

    def update_access_token(
        self,
        new_token: str,
        expires_in: int,
        refresh_token: str | None = None,
    ) -> None:
        """Update the access token and expiry time.

        Args:
            new_token (str): the newly refreshed access token
            expires_in (int): the number of seconds until the token expires
            refresh_token (str, optional): a new refresh token; the existing one is
                kept when this is None. Defaults to None.
        """
        self.access_token = new_token
        # Subtract 2.5 seconds to account for latency, as for the first time login
        self.expiry_epoch = time() + expires_in - 2.5

        if refresh_token is not None:
            self.refresh_token = refresh_token

    @property
    def is_expired(self) -> bool:
        """Check if the access token is expired.

        Returns:
            bool: True if the token is expired, False otherwise
        """
        return self.expiry_epoch < time()

is_expired: bool property

Check if the access token is expired.

Returns:

Name Type Description
bool bool

True if the token is expired, False otherwise

parse_first_time_login(value) classmethod

Parse the response from a first time login into a credentials object.

The following fields are returned per API:

| Field         | Google | Monzo | Spotify | TrueLayer |
|---------------|:------:|:-----:|:-------:|:---------:|
| access_token  |   X    |   X   |    X    |     X     |
| client_id     |   X    |   X   |    X    |     X     |
| expiry_epoch  |   X    |   X   |    X    |     X     |
| refresh_token |   X    |   X   |    X    |     X     |
| scope         |   X    |   X   |    X    |     X     |
| token_type    |   X    |   X   |    X    |     X     |
| client_secret |   X    |       |    X    |           |
| user_id       |        |   X   |         |           |
| token         |   X    |       |         |           |
| token_uri     |   X    |       |         |           |
| scopes        |   X    |       |         |           |

Parameters:

Name Type Description Default
value dict[str, Any]

the response from the API

required

Returns:

Name Type Description
OAuthCredentials OAuthCredentials

an OAuthCredentials instance

Raises:

Type Description
ValueError

if expiry differs by more than 60 seconds from the expiry time calculated from expires_in

Source code in wg_utilities/clients/oauth_client.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
@classmethod
def parse_first_time_login(cls, value: dict[str, Any]) -> OAuthCredentials:
    """Parse the response from a first time login into a credentials object.

    The expiry time is taken from the access token's `exp` claim when the token
    can be decoded as a JWT; otherwise it is calculated from `expires_in`. The
    `expires_in` key is never passed on to the credentials object, and the
    caller's dictionary is not modified.

    The following fields are returned per API:
    +---------------+--------+-------+---------+-----------+
    |               | Google | Monzo | Spotify | TrueLayer |
    +===============+========+=======+=========+===========+
    | access_token  |    X   |   X   |    X    |     X     |
    | client_id     |    X   |   X   |    X    |     X     |
    | expiry_epoch  |    X   |   X   |    X    |     X     |
    | refresh_token |    X   |   X   |    X    |     X     |
    | scope         |    X   |   X   |    X    |     X     |
    | token_type    |    X   |   X   |    X    |     X     |
    | client_secret |    X   |       |    X    |           |
    | user_id       |        |   X   |         |           |
    | token         |    X   |       |         |           |
    | token_uri     |    X   |       |         |           |
    | scopes        |    X   |       |         |           |
    +---------------+--------+-------+---------+-----------+

    Args:
        value: the response from the API

    Returns:
        OAuthCredentials: an OAuthCredentials instance

    Raises:
        KeyError: if the expiry can't be read from the access token and
            `expires_in` is missing
        ValueError: if `expiry` differs by more than 60 seconds from the expiry
            time calculated from `expires_in`
    """

    # Work on a shallow copy so the caller's dictionary isn't mutated by the
    # `pop`/assignment below.
    value = dict(value)

    # Calculate the expiry time of the access token
    try:
        # Try to decode it if it's a valid JWT (with expiry)
        expiry_epoch = decode(
            value["access_token"],
            options={"verify_signature": False},
        )["exp"]
        value.pop("expires_in", None)
    except (DecodeError, KeyError):
        # If that's not possible, calculate it from the expires_in value
        expires_in = value.pop("expires_in")

        # Subtract 2.5 seconds to account for latency
        expiry_epoch = time() + expires_in - 2.5

        # Verify it against the expiry time string
        if expiry_time_str := value.get("expiry"):
            expiry_time = datetime.fromisoformat(expiry_time_str)
            if abs(expiry_epoch - expiry_time.timestamp()) > 60:  # noqa: PLR2004
                # `from None` hides the decode failure, which is irrelevant here
                raise ValueError(
                    "`expiry` and `expires_in` are not consistent with each other:"
                    f" expiry: {expiry_time_str}, expires_in: {expires_in}"
                    f" (expiry_epoch: {expiry_epoch})",
                ) from None

    value["expiry_epoch"] = expiry_epoch

    return cls(**value)

update_access_token(new_token, expires_in, refresh_token=None)

Update the access token and expiry time.

Parameters:

Name Type Description Default
new_token str

the newly refreshed access token

required
expires_in int

the number of seconds until the token expires

required
refresh_token str

a new refresh token. Defaults to unset.

None
Source code in wg_utilities/clients/oauth_client.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
def update_access_token(
    self,
    new_token: str,
    expires_in: int,
    refresh_token: str | None = None,
) -> None:
    """Update the access token and expiry time.

    Args:
        new_token (str): the newly refreshed access token
        expires_in (int): the number of seconds until the token expires
        refresh_token (str, optional): a new refresh token. Defaults to unset.
    """
    self.access_token = new_token
    self.expiry_epoch = time() + expires_in - 2.5

    if refresh_token is not None:
        self.refresh_token = refresh_token

spotify

Custom client for interacting with Spotify's Web API.

Album

Bases: SpotifyEntity[AlbumSummaryJson]

An album on Spotify.

Source code in wg_utilities/clients/spotify.py
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
class Album(SpotifyEntity[AlbumSummaryJson]):
    """An album on Spotify.

    The `artists` and `tracks` properties build entity objects lazily from the
    raw JSON held in `artists_json` / `tracks_json` and cache the result on the
    instance.
    """

    album_group: Literal["album", "single", "compilation", "appears_on"] | None = None
    # Raw `album_type` value; both lower- and upper-case spellings are accepted
    # here and normalised by the `album_type` property.
    album_type_str: Literal[
        "single",
        "album",
        "compilation",
        "SINGLE",
        "ALBUM",
        "COMPILATION",
    ] = Field(alias="album_type")
    artists_json: list[ArtistSummaryJson] = Field(alias="artists")
    available_markets: list[str]
    copyrights: list[dict[str, str]] | None = None
    external_ids: dict[str, str] | None = None
    genres: list[str] | None = None
    images: list[Image]
    is_playable: bool | None = None
    label: str | None = None
    popularity: int | None = None
    # NOTE: `validate_release_date` reads this value via `info.data`, so it is
    # declared before `release_date`.
    release_date_precision: Literal["year", "month", "day"] | None = None
    release_date: date
    restrictions: dict[str, str] | None = None
    total_tracks: int
    # Defaults to an empty dict when the response carries no `tracks` key.
    tracks_json: PaginatedResponseTracks = Field(alias="tracks", default_factory=dict)  # type: ignore[assignment]
    type: Literal["album"]

    # Caches for the `artists` and `tracks` properties; unset until first access.
    _artists: list[Artist]
    _tracks: list[Track]

    sj_type: ClassVar[SpotifyEntityJsonType] = AlbumSummaryJson

    @field_validator("release_date", mode="before")
    @classmethod
    def validate_release_date(cls, value: str | date, info: ValidationInfo) -> date:
        """Convert the release date string to a date object.

        The string is split on "-" and the number of parts must agree with the
        release date precision: three parts for "day", two for "month", one for
        "year". Missing month/day components are filled with 1.

        Args:
            value (str | date): the raw release date; `date` instances are
                returned unchanged
            info (ValidationInfo): validation context; `release_date_precision`
                is read from its `data`

        Returns:
            date: the parsed release date

        Raises:
            ValueError: if the number of date components doesn't match the
                precision, or a component isn't a valid integer/date value
        """

        if isinstance(value, date):
            return value

        # A missing/None precision is treated as "day"
        rdp = (info.data.get("release_date_precision") or "day").lower()

        # Built once up front; raised from whichever branch detects a mismatch
        exception = ValueError(
            f"Incompatible release_date and release_date_precision values: {value!r}"
            f" and {rdp!r} respectively.",
        )

        match value.split("-"):
            case y, m, d:
                if rdp != "day":
                    raise exception
                return date(int(y), int(m), int(d))
            case y, m:
                if rdp != "month":
                    raise exception
                return date(int(y), int(m), 1)
            case (y,):
                if rdp != "year":
                    raise exception
                return date(int(y), 1, 1)
            case _:
                # Four or more components
                raise exception

    @property
    def album_type(self) -> AlbumType:
        """Convert the album type string to an enum value.

        Returns:
            AlbumType: the enum member matching the lowercased `album_type_str`
        """

        return AlbumType(self.album_type_str.lower())

    @property
    def artists(self) -> list[Artist]:
        """Return a list of artists who contributed to the album.

        The list is built from `artists_json` on first access and cached.

        Returns:
            list(Artist): a list of the artists who contributed to this album
        """

        if not hasattr(self, "_artists"):
            artists = [
                Artist.from_json_response(
                    item,
                    spotify_client=self.spotify_client,
                )
                for item in self.artists_json
            ]

            self._artists = artists

        return self._artists

    @property
    def tracks(self) -> list[Track]:
        """List of tracks on the album.

        Built on first access and cached. When `tracks_json` is populated its
        items are used (following the `next` URL for any further items);
        otherwise the tracks are requested from `/albums/{id}/tracks`.

        Returns:
            list: a list of tracks on this album
        """

        if not hasattr(self, "_tracks"):
            if self.tracks_json:
                # Initialise the list with data from the album JSON...
                tracks = [
                    Track.from_json_response(
                        item,
                        spotify_client=self.spotify_client,
                        additional_fields={"album": self.summary_json},
                    )
                    for item in self.tracks_json["items"]
                ]

                # ...then add the rest of the tracks from the API if necessary.
                if next_url := self.tracks_json.get("next"):
                    tracks.extend(
                        [
                            Track.from_json_response(
                                item,
                                spotify_client=self.spotify_client,
                                additional_fields={"album": self.summary_json},
                            )
                            for item in self.spotify_client.get_items(next_url)
                        ],
                    )
            else:
                # No track data came with the album JSON, so fetch all of it
                tracks = [
                    Track.from_json_response(
                        item,
                        spotify_client=self.spotify_client,
                        additional_fields={"album": self.summary_json},
                    )
                    for item in self.spotify_client.get_items(f"/albums/{self.id}/tracks")
                ]

            self._tracks = tracks

        return self._tracks

album_type: AlbumType property

Convert the album type string to an enum value.

artists: list[Artist] property

Return a list of artists who contributed to the album.

Returns:

Name Type Description
list Artist

a list of the artists who contributed to this album

tracks: list[Track] property

List of tracks on the album.

Returns:

Name Type Description
list list[Track]

a list of tracks on this album

validate_release_date(value, info) classmethod

Convert the release date string to a date object.

Source code in wg_utilities/clients/spotify.py
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
@field_validator("release_date", mode="before")
@classmethod
def validate_release_date(cls, value: str | date, info: ValidationInfo) -> date:
    """Turn a release date string into a `date`, checking it against its precision.

    The string is split on "-"; one, two or three components must be paired with
    a precision of "year", "month" or "day" respectively (a missing precision is
    treated as "day"). Absent month/day components default to 1. `date` instances
    are returned unchanged.

    Raises:
        ValueError: if the component count and the precision disagree, or a
            component isn't a valid integer/date value
    """

    if isinstance(value, date):
        return value

    rdp = (info.data.get("release_date_precision") or "day").lower()

    components = value.split("-")

    # Component count -> the only precision it may be paired with
    required_precision = {1: "year", 2: "month", 3: "day"}.get(len(components))

    if required_precision != rdp:
        raise ValueError(
            f"Incompatible release_date and release_date_precision values: {value!r}"
            f" and {rdp!r} respectively.",
        )

    numbers = [int(component) for component in components]

    # Pad the missing month and/or day with 1
    numbers += [1] * (3 - len(numbers))

    return date(*numbers)

AlbumType

Bases: StrEnum

Enum for the different types of album Spotify supports.

Source code in wg_utilities/clients/spotify.py
71
72
73
74
75
76
class AlbumType(StrEnum):
    """Enum for the different types of album Spotify supports.

    Member values are the lowercase type names.
    """

    SINGLE = "single"
    ALBUM = "album"
    COMPILATION = "compilation"

Artist

Bases: SpotifyEntity[ArtistSummaryJson]

An artist on Spotify.

Source code in wg_utilities/clients/spotify.py
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
class Artist(SpotifyEntity[ArtistSummaryJson]):
    """An artist on Spotify."""

    followers: Followers | None = None
    genres: list[str] | None = None
    images: list[Image] | None = None
    popularity: int | None = None
    type: Literal["artist"]

    # Cache for the `albums` property; unset until first access.
    _albums: list[Album]

    sj_type: ClassVar[SpotifyEntityJsonType] = ArtistSummaryJson

    @property
    def albums(self) -> list[Album]:
        """Albums by this artist, fetched on first access and cached afterwards.

        Returns:
            list: A list of albums this artist has contributed to
        """
        if hasattr(self, "_albums"):
            return self._albums

        client = self.spotify_client

        self._albums = [
            Album.from_json_response(album_json, spotify_client=client)
            for album_json in client.get_items(f"/artists/{self.id}/albums")
        ]

        return self._albums

albums: list[Album] property

Return a list of albums by this artist.

Returns:

Name Type Description
list list[Album]

A list of albums this artist has contributed to

Device

Bases: BaseModelWithConfig

Model for a Spotify device.

Source code in wg_utilities/clients/spotify.py
79
80
81
82
83
84
85
86
87
88
class Device(BaseModelWithConfig):
    """Model for a Spotify device.

    Every field is required; none declares a default value.
    """

    id: str
    is_active: bool
    is_private_session: bool
    is_restricted: bool
    name: str
    type: str
    volume_percent: int

ParsedSearchResponse

Bases: TypedDict

The return type of SpotifyClient.search.

Source code in wg_utilities/clients/spotify.py
62
63
64
65
66
67
68
class ParsedSearchResponse(TypedDict):
    """The return type of `SpotifyClient.search`.

    Every key is `NotRequired`, so any of them may be absent from the dict.
    """

    albums: NotRequired[list[Album]]
    artists: NotRequired[list[Artist]]
    playlists: NotRequired[list[Playlist]]
    tracks: NotRequired[list[Track]]

Playlist

Bases: SpotifyEntity[PlaylistSummaryJson]

A Spotify playlist.

Source code in wg_utilities/clients/spotify.py
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
class Playlist(SpotifyEntity[PlaylistSummaryJson]):
    """A Spotify playlist.

    Tracks and the owner are built lazily and cached. The cached tracklist is
    rebuilt when the playlist's live snapshot ID no longer matches the stored
    `snapshot_id`.
    """

    collaborative: bool
    followers: Followers | None = None
    images: list[Image] | None = None
    owner_json: UserSummaryJson = Field(alias="owner")
    primary_color: str | None = None
    # TODO: the `None` cases here can be handled by defaulting `public` to
    #  `self.owner.id == <user ID>`
    public: bool | None = None
    snapshot_id: str
    tracks_json: PaginatedResponsePlaylistTracks | PlaylistSummaryJsonTracks = Field(
        alias="tracks",
    )
    type: Literal["playlist"]

    # Caches for the `tracks` and `owner` properties; unset until first access.
    _tracks: list[Track]
    _owner: User

    sj_type: ClassVar[SpotifyEntityJsonType] = PlaylistSummaryJson
    # Most recently fetched snapshot ID and when it was fetched (see
    # `live_snapshot_id`).
    _live_snapshot_id_timestamp: datetime
    _live_snapshot_id: str

    @field_validator("tracks_json", mode="before")
    @classmethod
    def remove_local_tracks(
        cls,
        tracks_json: PaginatedResponsePlaylistTracks,
    ) -> PaginatedResponsePlaylistTracks:
        """Remove local tracks from the playlist's tracklist.

        Args:
            tracks_json: the raw `tracks` value; left as-is when it has no
                "items" key

        Returns:
            The same object, with items whose "is_local" value is truthy removed
            from "items" (the input is modified in place).
        """

        if "items" in tracks_json:
            tracks_json["items"] = [
                item for item in tracks_json["items"] if not item["is_local"]
            ]

        return tracks_json

    @property
    def live_snapshot_id(self) -> str:
        """The live snapshot ID of the playlist.

        The value is cached for a minute before being refreshed.

        Returns:
            str: the live snapshot ID of the playlist
        """
        # Fetch when nothing is cached yet or the cached value is over a minute old
        if (
            not hasattr(self, "_live_snapshot_id_timestamp")
            or not hasattr(self, "_live_snapshot_id")
            or datetime.now(UTC) - self._live_snapshot_id_timestamp > timedelta(minutes=1)
        ):
            # Only the `snapshot_id` field is requested
            self._live_snapshot_id = self.spotify_client.get_json_response(
                f"/playlists/{self.id}",
                params={"fields": "snapshot_id"},
            )[
                "snapshot_id"  # type: ignore[typeddict-item]
            ]

            self._live_snapshot_id_timestamp = datetime.now(UTC)

        return self._live_snapshot_id

    @property
    def owner(self) -> User:
        """Playlist owner.

        Built from `owner_json` on first access and cached.

        Returns:
            User: the Spotify user who owns this playlist
        """

        if not hasattr(self, "_owner"):
            self._owner = User.from_json_response(
                self.owner_json,
                spotify_client=self.spotify_client,
            )

        return self._owner

    @property
    def tracks(self) -> list[Track]:
        """Return a list of tracks in the playlist.

        The list is fetched from `/playlists/{id}/tracks` on first access, and
        again whenever `updates_available` is True. Items with no track, or whose
        `is_local` value isn't False, are left out.

        Returns:
            list: a list of tracks in this playlist
        """

        # `updates_available` is only evaluated once a tracklist has been cached
        if not hasattr(self, "_tracks") or self.updates_available:
            tracks = [
                Track.from_json_response(
                    item["track"],
                    spotify_client=self.spotify_client,
                )
                for item in cast(
                    list[PlaylistFullJsonTracks],
                    self.spotify_client.get_items(f"/playlists/{self.id}/tracks"),
                )
                if item.get("track") is not None and item["is_local"] is False
            ]

            self._tracks = tracks

            # Bring the stored and live snapshot IDs in line with each other: adopt
            # the live ID if one has been fetched, otherwise seed it from the
            # stored one.
            if hasattr(self, "_live_snapshot_id"):
                self.snapshot_id = self._live_snapshot_id
            else:
                self._live_snapshot_id = self.snapshot_id

        return self._tracks

    @property
    def updates_available(self) -> bool:
        """Check if the playlist has updates available.

        Returns:
            bool: whether the playlist has updates available, i.e. the live
                snapshot ID differs from the stored `snapshot_id`
        """
        return self.live_snapshot_id != self.snapshot_id

    def __contains__(self, track: Track) -> bool:
        """Check if a track is in the playlist."""
        return track in self.tracks

    def __gt__(self, other: object) -> bool:
        """Compare two playlists by name and ID.

        The comparison is case-insensitive, by name first and then by ID. Equal
        playlists are never greater than each other; non-Playlist operands give
        `NotImplemented`.
        """
        if not isinstance(other, Playlist):
            return NotImplemented

        if self == other:
            return False

        return (self.name.lower(), self.id.lower()) > (
            other.name.lower(),
            other.id.lower(),
        )

    def __iter__(self) -> Iterator[Track]:  # type: ignore[override]
        """Iterate over the tracks in the playlist."""
        return iter(self.tracks)

    def __lt__(self, other: object) -> bool:
        """Compare two playlists by name and ID.

        The comparison is case-insensitive, by name first and then by ID. Equal
        playlists are never less than each other; non-Playlist operands give
        `NotImplemented`.
        """
        if not isinstance(other, Playlist):
            return NotImplemented

        if self == other:
            return False

        return (self.name.lower(), self.id.lower()) < (
            other.name.lower(),
            other.id.lower(),
        )

live_snapshot_id: str property

The live snapshot ID of the playlist.

The value is cached for a minute before being refreshed.

Returns:

Name Type Description
str str

the live snapshot ID of the playlist

owner: User property

Playlist owner.

Returns:

Name Type Description
User User

the Spotify user who owns this playlist

tracks: list[Track] property

Return a list of tracks in the playlist.

Returns:

Name Type Description
list list[Track]

a list of tracks in this playlist

updates_available: bool property

Check if the playlist has updates available.

Returns:

Name Type Description
bool bool

whether the playlist has updates available

__contains__(track)

Check if a track is in the playlist.

Source code in wg_utilities/clients/spotify.py
1084
1085
1086
def __contains__(self, track: Track) -> bool:
    """Report whether the given track appears in this playlist's tracklist."""
    playlist_tracks = self.tracks
    return track in playlist_tracks

__gt__(other)

Compare two playlists by name and ID.

Source code in wg_utilities/clients/spotify.py
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
def __gt__(self, other: object) -> bool:
    """Order playlists case-insensitively by name, then by ID.

    Returns `NotImplemented` for non-Playlist operands, and False when the two
    playlists compare equal.
    """
    if isinstance(other, Playlist):
        if self == other:
            return False

        own_key = (self.name.lower(), self.id.lower())
        other_key = (other.name.lower(), other.id.lower())

        return own_key > other_key

    return NotImplemented

__iter__()

Iterate over the tracks in the playlist.

Source code in wg_utilities/clients/spotify.py
1101
1102
1103
def __iter__(self) -> Iterator[Track]:  # type: ignore[override]
    """Return an iterator over this playlist's tracklist."""
    playlist_tracks = self.tracks
    return iter(playlist_tracks)

__lt__(other)

Compare two playlists by name and ID.

Source code in wg_utilities/clients/spotify.py
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
def __lt__(self, other: object) -> bool:
    """Order playlists case-insensitively by name, then by ID.

    Returns `NotImplemented` for non-Playlist operands, and False when the two
    playlists compare equal.
    """
    if isinstance(other, Playlist):
        if self == other:
            return False

        own_key = (self.name.lower(), self.id.lower())
        other_key = (other.name.lower(), other.id.lower())

        return own_key < other_key

    return NotImplemented

remove_local_tracks(tracks_json) classmethod

Remove local tracks from the playlist's tracklist.

Source code in wg_utilities/clients/spotify.py
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
@field_validator("tracks_json", mode="before")
@classmethod
def remove_local_tracks(
    cls,
    tracks_json: PaginatedResponsePlaylistTracks,
) -> PaginatedResponsePlaylistTracks:
    """Drop local tracks from the raw tracklist before it is validated.

    A value without an "items" key is returned untouched; otherwise its "items"
    list is replaced (in place) by one holding only the entries whose "is_local"
    value is falsy.
    """

    if "items" not in tracks_json:
        return tracks_json

    tracks_json["items"] = list(
        filter(lambda entry: not entry["is_local"], tracks_json["items"]),
    )

    return tracks_json

SpotifyClient

Bases: OAuthClient[SpotifyEntityJson]

Custom client for interacting with Spotify's Web API.

For authentication purposes either an already-instantiated OAuth manager or the relevant credentials must be provided

Parameters:

Name Type Description Default
client_id str

the application's client ID

None
client_secret str

the application's client secret

None
log_requests bool

flag for choosing if to log all requests made

False
creds_cache_path str

path at which to save cached credentials

None
Source code in wg_utilities/clients/spotify.py
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
class SpotifyClient(OAuthClient[SpotifyEntityJson]):
    """Custom client for interacting with Spotify's Web API.

    For authentication purposes either an already-instantiated OAuth manager or the
    relevant credentials must be provided

    Args:
        client_id (str): the application's client ID
        client_secret (str): the application's client secret
        log_requests (bool): flag for choosing whether to log all requests made
        creds_cache_path (str): path at which to save cached credentials
    """

    AUTH_LINK_BASE = "https://accounts.spotify.com/authorize"
    ACCESS_TOKEN_ENDPOINT = "https://accounts.spotify.com/api/token"  # noqa: S105
    BASE_URL = "https://api.spotify.com/v1"

    DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

    DEFAULT_SCOPES: ClassVar[list[str]] = [
        "ugc-image-upload",
        "user-read-recently-played",
        "user-top-read",
        "user-read-playback-position",
        "user-read-playback-state",
        "user-modify-playback-state",
        "user-read-currently-playing",
        "app-remote-control",
        "streaming",
        "playlist-modify-public",
        "playlist-modify-private",
        "playlist-read-private",
        "playlist-read-collaborative",
        "user-follow-modify",
        "user-follow-read",
        "user-library-modify",
        "user-library-read",
        "user-read-email",
        "user-read-private",
    ]

    # The entity types accepted by `search`; anything else raises a ValueError there
    SEARCH_TYPES: tuple[Literal["album", "artist", "playlist", "track"], ...] = (
        "album",
        "artist",
        "playlist",
        "track",
        # "show",
        # "episode",
    )

    # Lazily populated by the `current_user` property on first access
    _current_user: User

    def add_tracks_to_playlist(
        self,
        tracks: Iterable[Track],
        playlist: Playlist,
        *,
        log_responses: bool = False,
        force_add: bool = False,
        update_instance_tracklist: bool = True,
    ) -> list[Track]:
        """Add one or more tracks to a playlist.

        If `force_add` is False, a check is made against the Playlist's tracklist to
        ensure that the track isn't already in the playlist. This can be slow if it's
        a (new) big playlist.

        Local tracks are always skipped. The remaining tracks are POSTed in chunks
        of 100 URIs per request.

        Args:
            tracks (Iterable[Track]): the Track instances to be added to the given
                playlist
            playlist (Playlist): the playlist being updated
            log_responses (bool): log each individual response
            force_add (bool): flag for adding the track even if it's in the playlist
                already
            update_instance_tracklist (bool): appends the track to the Playlist's
                tracklist. Can be slow if it's a big playlist as it might have to get
                the tracklist first

        Returns:
            list[Track]: the tracks which were actually sent to the API, i.e. the
                input minus local tracks and (unless `force_add` is set) tracks
                already in the playlist
        """

        tracks_to_add = [
            track
            for track in tracks
            if not track.is_local and (force_add or track not in playlist)
        ]

        for chunk in chunk_list(tracks_to_add, 100):
            res = self._post(
                f"/playlists/{playlist.id}/tracks",
                json={"uris": [t.uri for t in chunk]},
            )

            if log_responses:
                LOGGER.info(dumps(res.json()))

        if update_instance_tracklist:
            playlist.tracks.extend(tracks_to_add)

        return tracks_to_add

    def create_playlist(
        self,
        *,
        name: str,
        description: str = "",
        public: bool = False,
        collaborative: bool = False,
    ) -> Playlist:
        """Create a new playlist under the current user's account.

        Args:
            name (str): the name of the new playlist
            description (str): the description of the new playlist
            public (bool): determines if the playlist is publicly accessible
            collaborative (bool): allows other people to add tracks when True

        Returns:
            Playlist: an instance of the Playlist class containing the new playlist's
                metadata
        """
        res = self._post(
            f"/users/{self.current_user.id}/playlists",
            json={
                "name": name,
                "description": description,
                "public": public,
                "collaborative": collaborative,
            },
        )

        return Playlist.from_json_response(res.json(), spotify_client=self)

    def get_album_by_id(self, id_: str) -> Album:
        """Get an album from Spotify based on the ID.

        Args:
            id_ (str): the Spotify ID which is used to find the album

        Returns:
            Album: an instantiated Album, from the API's response
        """

        return Album.from_json_response(
            self.get_json_response(f"/albums/{id_}"),
            spotify_client=self,
        )

    def get_artist_by_id(self, id_: str) -> Artist:
        """Get an artist from Spotify based on the ID.

        Args:
            id_ (str): the Spotify ID which is used to find the artist

        Returns:
            Artist: an instantiated Artist, from the API's response
        """

        return Artist.from_json_response(
            self.get_json_response(f"/artists/{id_}"),
            spotify_client=self,
        )

    def get_items(
        self,
        url: str,
        *,
        params: None | dict[str, str | int | float | bool | dict[str, Any]] = None,
        hard_limit: int = 1000000,
        limit_func: (
            Callable[
                [dict[str, Any] | SpotifyEntityJson],
                bool,
            ]
            | None
        ) = None,
        top_level_key: (
            Literal[
                "albums",
                "artists",
                "audiobooks",
                "episodes",
                "playlists",
                "shows",
                "tracks",
            ]
            | None
        ) = None,
        list_key: Literal["items", "devices"] = "items",
    ) -> list[SpotifyEntityJson]:
        """Retrieve a list of items from a given URL, including pagination.

        Args:
            url (str): the API endpoint which we're listing
            params (dict): any params to pass with the API request. The dict is
                copied, so the caller's instance is never modified
            hard_limit (int): a hard limit to apply to the number of items returned (as
                opposed to the "soft" limit of 50 imposed by the API)
            limit_func (Callable): a function which is used to evaluate each item in
                turn: if it returns False, the item is added to the output list; if it
                returns True then the iteration stops and the list is returned as-is
            top_level_key (str): an optional key to use when the items in the response
                are nested 1 level deeper than normal
            list_key (Literal["devices", "items"]): the key in the response which
                contains the list of items

        Returns:
            list: a list of dicts representing the Spotify items
        """

        # Work on a copy so that injecting `limit` doesn't mutate the caller's dict
        params = dict(params) if params else {}
        if "limit=" not in url:
            params["limit"] = min(50, hard_limit)

        items: list[SpotifyEntityJson] = []

        if params:
            url += ("?" if "?" not in url else "&") + urlencode(params)

        # Fake "page zero" whose `next` link points at the first real page, so that
        # the loop below can treat the first request like any other
        page: AnyPaginatedResponse = {
            "href": "",
            "items": [],
            "limit": 0,
            "next": url,
            "offset": 0,
            "total": 0,
        }

        while (next_url := page.get("next")) and len(items) < hard_limit:
            # Ensure we don't bother getting more items than we need
            limit = min(50, hard_limit - len(items))
            # Replace the whole numeric value: matching only 1-2 digits would leave
            # trailing digits behind (e.g. `limit=100` -> `limit=500`)
            next_url = sub(r"(?<=limit=)\d+", str(limit), next_url)

            res: SearchResponse | AnyPaginatedResponse = self.get_json_response(next_url)  # type: ignore[assignment]
            page = (
                cast(SearchResponse, res)[top_level_key]
                if top_level_key
                else cast(AnyPaginatedResponse, res)
            )

            page_items: (
                list[AlbumSummaryJson]
                | list[DeviceJson]
                | list[ArtistSummaryJson]
                | list[PlaylistSummaryJson]
                | list[TrackFullJson]
            ) = page.get(list_key, [])
            if limit_func is None:
                items.extend(page_items)
            else:
                # Initialise `limit_reached` to False, and then it will be set to
                # True on the first matching item. This will then cause the loop to
                # skip subsequent items - not as good as a `break` but still kind of
                # elegant imho!
                limit_reached = False
                items.extend(
                    [
                        item
                        for item in page_items
                        if (not (limit_reached := (limit_reached or limit_func(item))))
                    ],
                )
                if limit_reached:
                    return items

        return items

    def get_playlist_by_id(self, id_: str) -> Playlist:
        """Get a playlist from Spotify based on the ID.

        If the current user and their playlists have already been loaded, those
        playlists are checked first so that a matching instance can be returned
        without another API request.

        Args:
            id_ (str): the Spotify ID which is used to find the playlist

        Returns:
            Playlist: an instantiated Playlist, from the API's response
        """

        # `hasattr` checks avoid triggering any lazy loading just for this lookup
        if hasattr(self, "_current_user") and hasattr(self.current_user, "_playlists"):
            for playlist in self.current_user.playlists:
                if playlist.id == id_:
                    return playlist

        return Playlist.from_json_response(
            self.get_json_response(f"/playlists/{id_}"),
            spotify_client=self,
        )

    def get_track_by_id(self, id_: str) -> Track:
        """Get a track from Spotify based on the ID.

        Args:
            id_ (str): the Spotify ID which is used to find the track

        Returns:
            Track: an instantiated Track, from the API's response
        """

        return Track.from_json_response(
            self.get_json_response(f"/tracks/{id_}"),
            spotify_client=self,
        )

    @overload
    def search(
        self,
        search_term: str,
        *,
        entity_types: Sequence[Literal["album", "artist", "playlist", "track"]] = (),
        get_best_match_only: Literal[True],
    ) -> Artist | Playlist | Track | Album | None: ...

    @overload
    def search(
        self,
        search_term: str,
        *,
        entity_types: Sequence[Literal["album", "artist", "playlist", "track"]] = (),
        get_best_match_only: Literal[False] = False,
    ) -> ParsedSearchResponse: ...

    def search(
        self,
        search_term: str,
        *,
        entity_types: Sequence[Literal["album", "artist", "playlist", "track"]] = (),
        get_best_match_only: bool = False,
    ) -> Artist | Playlist | Track | Album | None | ParsedSearchResponse:
        """Search Spotify for a given search term.

        Args:
            search_term (str): the term to use as the base of the search
            entity_types (Sequence[str]): the types of entity to search for. Each
                must be one of SpotifyClient.SEARCH_TYPES; if empty, all of
                SpotifyClient.SEARCH_TYPES are searched
            get_best_match_only (bool): return a single entity from the top of the
                list, rather than all matches

        Returns:
            Artist | Playlist | Track | Album | None: a single entity if the best
                match flag is set, or None if there were no results
            dict: a dict of entities, by type

        Raises:
            ValueError: if the best match flag is true and the number of requested
                entity types is not exactly one
            ValueError: if one of entity_types is an invalid value
        """

        entity_types = entity_types or self.SEARCH_TYPES

        if get_best_match_only is True and len(entity_types) != 1:
            raise ValueError(
                "Exactly one entity type must be requested if `get_best_match_only`"
                " is True",
            )

        entity_type: Literal["artist", "playlist", "track", "album"]
        for entity_type in entity_types:
            if entity_type not in self.SEARCH_TYPES:
                raise ValueError(
                    f"Unexpected value for entity type: '{entity_type}'. Must be"
                    f" one of {self.SEARCH_TYPES!r}",
                )

        res: SearchResponse = self.get_json_response(  # type: ignore[assignment]
            "/search",
            params={
                "query": search_term,
                "type": ",".join(entity_types),
                "limit": 1 if get_best_match_only else 50,
            },
        )

        entity_instances: ParsedSearchResponse = {}

        res_entity_type: Literal["albums", "artists", "playlists", "tracks"]
        entities_json: (
            PaginatedResponseAlbums
            | PaginatedResponseArtists
            | PaginatedResponsePlaylists
            | PaginatedResponseTracks
        )
        for res_entity_type, entities_json in res.items():  # type: ignore[assignment]
            # Response keys are the pluralised entity types
            instance_class: type[Album] | type[Artist] | type[Playlist] | type[Track] = {  # type: ignore[assignment]
                "albums": Album,
                "artists": Artist,
                "playlists": Playlist,
                "tracks": Track,
            }[res_entity_type]

            if get_best_match_only:
                try:
                    # Take the entity off the top of the list
                    return instance_class.from_json_response(
                        entities_json["items"][0],
                        spotify_client=self,
                    )
                except LookupError:
                    # Covers both a missing "items" key and an empty list
                    return None

            entity_instances.setdefault(res_entity_type, []).extend(
                [
                    instance_class.from_json_response(entity_json, spotify_client=self)  # type: ignore[misc]
                    for entity_json in entities_json.get("items", [])
                ],
            )

            # Each entity type has its own type-specific next URL
            if (next_url := entities_json.get("next")) is not None:
                entity_instances[res_entity_type].extend(
                    [
                        instance_class.from_json_response(  # type: ignore[misc]
                            item,
                            spotify_client=self,
                        )
                        for item in self.get_items(
                            next_url,
                            top_level_key=res_entity_type,
                        )
                    ],
                )

        if get_best_match_only:
            # The response contained no entity types at all: honour the overload's
            # contract by returning None rather than an empty dict
            return None

        return entity_instances

    @property
    def current_user(self) -> User:
        """Get the current user's info.

        The user is fetched from the `/me` endpoint on first access and then cached
        on the instance.

        Returns:
            User: an instance of the current Spotify user
        """
        if not hasattr(self, "_current_user"):
            self._current_user = User.from_json_response(
                self.get_json_response("/me"),
                spotify_client=self,
            )

        return self._current_user

current_user: User property

Get the current user's info.

Returns:

Name Type Description
User User

an instance of the current Spotify user

add_tracks_to_playlist(tracks, playlist, *, log_responses=False, force_add=False, update_instance_tracklist=True)

Add one or more tracks to a playlist.

If force_add is False, a check is made against the Playlist's tracklist to ensure that the track isn't already in the playlist. This can be slow if it's a (new) big playlist.

Parameters:

Name Type Description Default
tracks list

a list of Track instances to be added to the given playlist

required
playlist Playlist

the playlist being updated

required
log_responses bool

log each individual response

False
force_add bool

flag for adding the track even if it's in the playlist already

False
update_instance_tracklist bool

appends the track to the Playlist's tracklist. Can be slow if it's a big playlist as it might have to get the tracklist first

True
Source code in wg_utilities/clients/spotify.py
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
def add_tracks_to_playlist(
    self,
    tracks: Iterable[Track],
    playlist: Playlist,
    *,
    log_responses: bool = False,
    force_add: bool = False,
    update_instance_tracklist: bool = True,
) -> list[Track]:
    """Append tracks to a playlist, skipping local tracks and optionally duplicates.

    Unless `force_add` is set, every candidate is checked against the playlist
    before being queued, which can be slow for a big playlist whose tracklist
    has not been loaded yet. Queued tracks are POSTed 100 URIs at a time.

    Args:
        tracks (Iterable[Track]): the Track instances to add
        playlist (Playlist): the playlist being updated
        log_responses (bool): log the JSON body of each individual response
        force_add (bool): add tracks even when they are already in the playlist
        update_instance_tracklist (bool): also extend the Playlist instance's
            tracklist with the added tracks. Can be slow if it's a big playlist
            as it might have to get the tracklist first

    Returns:
        list[Track]: the tracks which were sent to the API
    """

    tracks_to_add: list[Track] = []
    for candidate in tracks:
        # Local tracks are never sent to the API
        if candidate.is_local:
            continue
        if force_add or candidate not in playlist:
            tracks_to_add.append(candidate)

    for batch in chunk_list(tracks_to_add, 100):
        endpoint = f"/playlists/{playlist.id}/tracks"
        payload = {"uris": [item.uri for item in batch]}
        response = self._post(endpoint, json=payload)

        if log_responses:
            LOGGER.info(dumps(response.json()))

    if update_instance_tracklist:
        playlist.tracks.extend(tracks_to_add)

    return tracks_to_add

create_playlist(*, name, description='', public=False, collaborative=False)

Create a new playlist under the current user's account.

Parameters:

Name Type Description Default
name str

the name of the new playlist

required
description str

the description of the new playlist

''
public bool

determines if the playlist is publicly accessible

False
collaborative bool

allows other people to add tracks when True

False

Returns:

Name Type Description
Playlist Playlist

an instance of the Playlist class containing the new playlist's metadata

Source code in wg_utilities/clients/spotify.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
def create_playlist(
    self,
    *,
    name: str,
    description: str = "",
    public: bool = False,
    collaborative: bool = False,
) -> Playlist:
    """Create a playlist owned by the current user.

    Args:
        name (str): the name of the new playlist
        description (str): the description of the new playlist
        public (bool): whether the playlist is publicly accessible
        collaborative (bool): when True, other people may add tracks

    Returns:
        Playlist: a Playlist instance built from the API's response to the
            creation request
    """
    endpoint = f"/users/{self.current_user.id}/playlists"
    payload = {
        "name": name,
        "description": description,
        "public": public,
        "collaborative": collaborative,
    }

    response = self._post(endpoint, json=payload)

    return Playlist.from_json_response(response.json(), spotify_client=self)

get_album_by_id(id_)

Get an album from Spotify based on the ID.

Parameters:

Name Type Description Default
id_ str

the Spotify ID which is used to find the album

required

Returns:

Name Type Description
Album Album

an instantiated Album, from the API's response

Source code in wg_utilities/clients/spotify.py
247
248
249
250
251
252
253
254
255
256
257
258
259
260
def get_album_by_id(self, id_: str) -> Album:
    """Fetch a single album from Spotify by its ID.

    Args:
        id_ (str): the Spotify ID of the album to look up

    Returns:
        Album: an Album instantiated from the API's JSON response
    """

    album_json = self.get_json_response(f"/albums/{id_}")

    return Album.from_json_response(album_json, spotify_client=self)

get_artist_by_id(id_)

Get an artist from Spotify based on the ID.

Parameters:

Name Type Description Default
id_ str

the Spotify ID which is used to find the artist

required

Returns:

Name Type Description
Artist Artist

an instantiated Artist, from the API's response

Source code in wg_utilities/clients/spotify.py
262
263
264
265
266
267
268
269
270
271
272
273
274
275
def get_artist_by_id(self, id_: str) -> Artist:
    """Fetch a single artist from Spotify by its ID.

    Args:
        id_ (str): the Spotify ID of the artist to look up

    Returns:
        Artist: an Artist instantiated from the API's JSON response
    """

    artist_json = self.get_json_response(f"/artists/{id_}")

    return Artist.from_json_response(artist_json, spotify_client=self)

get_items(url, *, params=None, hard_limit=1000000, limit_func=None, top_level_key=None, list_key='items')

Retrieve a list of items from a given URL, including pagination.

Parameters:

Name Type Description Default
url str

the API endpoint which we're listing

required
params dict

any params to pass with the API request

None
hard_limit int

a hard limit to apply to the number of items returned (as opposed to the "soft" limit of 50 imposed by the API)

1000000
limit_func Callable

a function which is used to evaluate each item in turn: if it returns False, the item is added to the output list; if it returns True then the iteration stops and the list is returned as-is

None
top_level_key str

an optional key to use when the items in the response are nested 1 level deeper than normal

None
list_key Literal['devices', 'items']

the key in the response which contains the list of items

'items'

Returns:

Name Type Description
list list[SpotifyEntityJson]

a list of dicts representing the Spotify items

Source code in wg_utilities/clients/spotify.py
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
def get_items(
    self,
    url: str,
    *,
    params: None | dict[str, str | int | float | bool | dict[str, Any]] = None,
    hard_limit: int = 1000000,
    limit_func: (
        Callable[
            [dict[str, Any] | SpotifyEntityJson],
            bool,
        ]
        | None
    ) = None,
    top_level_key: (
        Literal[
            "albums",
            "artists",
            "audiobooks",
            "episodes",
            "playlists",
            "shows",
            "tracks",
        ]
        | None
    ) = None,
    list_key: Literal["items", "devices"] = "items",
) -> list[SpotifyEntityJson]:
    """Retrieve a list of items from a given URL, including pagination.

    Args:
        url (str): the API endpoint which we're listing
        params (dict): any params to pass with the API request. The dict is
            copied, so the caller's instance is never modified
        hard_limit (int): a hard limit to apply to the number of items returned (as
            opposed to the "soft" limit of 50 imposed by the API)
        limit_func (Callable): a function which is used to evaluate each item in
            turn: if it returns False, the item is added to the output list; if it
            returns True then the iteration stops and the list is returned as-is
        top_level_key (str): an optional key to use when the items in the response
            are nested 1 level deeper than normal
        list_key (Literal["devices", "items"]): the key in the response which
            contains the list of items

    Returns:
        list: a list of dicts representing the Spotify items
    """

    # Work on a copy so that injecting `limit` doesn't mutate the caller's dict
    params = dict(params) if params else {}
    if "limit=" not in url:
        params["limit"] = min(50, hard_limit)

    items: list[SpotifyEntityJson] = []

    if params:
        url += ("?" if "?" not in url else "&") + urlencode(params)

    # Fake "page zero" whose `next` link points at the first real page, so that
    # the loop below can treat the first request like any other
    page: AnyPaginatedResponse = {
        "href": "",
        "items": [],
        "limit": 0,
        "next": url,
        "offset": 0,
        "total": 0,
    }

    while (next_url := page.get("next")) and len(items) < hard_limit:
        # Ensure we don't bother getting more items than we need
        limit = min(50, hard_limit - len(items))
        # Replace the whole numeric value: matching only 1-2 digits would leave
        # trailing digits behind (e.g. `limit=100` -> `limit=500`)
        next_url = sub(r"(?<=limit=)\d+", str(limit), next_url)

        res: SearchResponse | AnyPaginatedResponse = self.get_json_response(next_url)  # type: ignore[assignment]
        page = (
            cast(SearchResponse, res)[top_level_key]
            if top_level_key
            else cast(AnyPaginatedResponse, res)
        )

        page_items: (
            list[AlbumSummaryJson]
            | list[DeviceJson]
            | list[ArtistSummaryJson]
            | list[PlaylistSummaryJson]
            | list[TrackFullJson]
        ) = page.get(list_key, [])
        if limit_func is None:
            items.extend(page_items)
        else:
            # Initialise `limit_reached` to False, and then it will be set to
            # True on the first matching item. This will then cause the loop to
            # skip subsequent items - not as good as a `break` but still kind of
            # elegant imho!
            limit_reached = False
            items.extend(
                [
                    item
                    for item in page_items
                    if (not (limit_reached := (limit_reached or limit_func(item))))
                ],
            )
            if limit_reached:
                return items

    return items

get_playlist_by_id(id_)

Get a playlist from Spotify based on the ID.

Parameters:

Name Type Description Default
id_ str

the Spotify ID which is used to find the playlist

required

Returns:

Name Type Description
Playlist Playlist

an instantiated Playlist, from the API's response

Source code in wg_utilities/clients/spotify.py
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
def get_playlist_by_id(self, id_: str) -> Playlist:
    """Fetch a single playlist by its ID, preferring already-loaded instances.

    When both the current user and their playlists have already been loaded,
    those playlists are searched first; the API is only queried if no loaded
    playlist has a matching ID.

    Args:
        id_ (str): the Spotify ID of the playlist to look up

    Returns:
        Playlist: the matching already-loaded Playlist, or one instantiated from
            the API's JSON response
    """

    user_loaded = hasattr(self, "_current_user")
    if user_loaded and hasattr(self.current_user, "_playlists"):
        cached = next(
            (
                candidate
                for candidate in self.current_user.playlists
                if candidate.id == id_
            ),
            None,
        )
        if cached is not None:
            return cached

    playlist_json = self.get_json_response(f"/playlists/{id_}")

    return Playlist.from_json_response(playlist_json, spotify_client=self)

get_track_by_id(id_)

Get a track from Spotify based on the ID.

Parameters:

Name Type Description Default
id_ str

the Spotify ID which is used to find the track

required

Returns:

Name Type Description
Track Track

an instantiated Track, from the API's response

Source code in wg_utilities/clients/spotify.py
400
401
402
403
404
405
406
407
408
409
410
411
412
413
def get_track_by_id(self, id_: str) -> Track:
    """Fetch a single track from Spotify by its ID.

    Args:
        id_ (str): the Spotify ID of the track to look up

    Returns:
        Track: a Track instantiated from the API's JSON response
    """

    track_json = self.get_json_response(f"/tracks/{id_}")

    return Track.from_json_response(track_json, spotify_client=self)

search(search_term, *, entity_types=(), get_best_match_only=False)

Search Spotify for a given search term.

Parameters:

Name Type Description Default
search_term str

the term to use as the base of the search

required
entity_types Sequence[str]

the types of entity to search for. Each must be one of SpotifyClient.SEARCH_TYPES

()
get_best_match_only bool

return a single entity from the top of the list, rather than all matches

False

Returns:

Name Type Description
Artist | Playlist | Track | Album | None | ParsedSearchResponse

Artist | Playlist | Track | Album: a single entity if the best match flag is set

dict Artist | Playlist | Track | Album | None | ParsedSearchResponse

a dict of entities, by type

Raises:

Type Description
ValueError

if multiple entity types have been requested but the best match flag is true

ValueError

if one of entity_types is an invalid value

Source code in wg_utilities/clients/spotify.py
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
def search(
    self,
    search_term: str,
    *,
    entity_types: Sequence[Literal["album", "artist", "playlist", "track"]] = (),
    get_best_match_only: bool = False,
) -> Artist | Playlist | Track | Album | None | ParsedSearchResponse:
    """Run a Spotify search and parse the results into entity models.

    Args:
        search_term (str): the text to search for
        entity_types (Sequence[str]): the kinds of entity to search for; every
            value must be in SpotifyClient.SEARCH_TYPES. When empty, all of
            SpotifyClient.SEARCH_TYPES are used
        get_best_match_only (bool): return only the first result, rather than
            every match

    Returns:
        Artist | Playlist | Track | Album | None: the first result when the best
            match flag is set, or None if there was no result to take
        dict: otherwise, lists of entities keyed by the plural entity type

    Raises:
        ValueError: if the best match flag is set and the number of entity types
            is not exactly one
        ValueError: if one of entity_types is an invalid value
    """

    requested_types = entity_types or self.SEARCH_TYPES

    if get_best_match_only is True and len(requested_types) != 1:
        raise ValueError(
            "Exactly one entity type must be requested if"
            " `get_best_match_only` is True",
        )

    for requested_type in requested_types:
        if requested_type in self.SEARCH_TYPES:
            continue

        raise ValueError(
            f"Unexpected value for entity type: '{requested_type}'. Must be one of"
            f" {self.SEARCH_TYPES!r}",
        )

    query_params = {
        "query": search_term,
        "type": ",".join(requested_types),
        "limit": 1 if get_best_match_only else 50,
    }
    res: SearchResponse = self.get_json_response(  # type: ignore[assignment]
        "/search",
        params=query_params,
    )

    # Maps each top-level key of the search response to the model which parses it
    model_by_key: dict[
        str,
        type[Album] | type[Artist] | type[Playlist] | type[Track],
    ] = {
        "albums": Album,
        "artists": Artist,
        "playlists": Playlist,
        "tracks": Track,
    }

    parsed: ParsedSearchResponse = {}

    for res_entity_type, entities_json in res.items():
        model = model_by_key[res_entity_type]  # type: ignore[index]

        if get_best_match_only:
            try:
                # Only the entity at the top of the list is wanted
                return model.from_json_response(
                    entities_json["items"][0],  # type: ignore[index]
                    spotify_client=self,
                )
            except LookupError:
                return None

        results = parsed.setdefault(res_entity_type, [])  # type: ignore[misc]
        results.extend(
            [
                model.from_json_response(entity_json, spotify_client=self)  # type: ignore[misc]
                for entity_json in entities_json.get("items", [])  # type: ignore[attr-defined]
            ],
        )

        # Each entity type has its own type-specific next URL
        next_url = entities_json.get("next")  # type: ignore[attr-defined]
        if next_url is None:
            continue

        results.extend(
            [
                model.from_json_response(item, spotify_client=self)  # type: ignore[misc]
                for item in self.get_items(next_url, top_level_key=res_entity_type)
            ],
        )

    return parsed

SpotifyEntity

Bases: BaseModelWithConfig, Generic[SJ]

Base model for Spotify entities.

Source code in wg_utilities/clients/spotify.py
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
class SpotifyEntity(BaseModelWithConfig, Generic[SJ]):
    """Base model for Spotify entities.

    Declares the fields shared by the entity models, keeps a reference to the
    `SpotifyClient` used to build the entity, and stores the part of the input
    whose keys appear in `sj_type` as `summary_json`.

    Two entities are equal, and hash equally, when their `uri` matches. Ordering
    compares the lower-cased `name`, falling back to `id` when the name is empty.
    """

    description: str = ""
    external_urls: dict[Literal["spotify"], str]
    href: str
    id: str
    name: str = ""
    uri: str

    # Free-form extras, e.g. supplied via `from_json_response(metadata=...)`
    metadata: dict[str, Any] = Field(default_factory=dict)
    spotify_client: SpotifyClient = Field(exclude=True)

    # Populated by `_set_summary_json` before validation
    summary_json: SJ = Field(default_factory=dict, frozen=True, exclude=True)  # type: ignore[assignment]
    # The JSON type whose annotations decide which input keys go into `summary_json`
    sj_type: ClassVar[TypeAlias] = SpotifyBaseEntityJson

    @model_validator(mode="before")
    @classmethod
    def _set_summary_json(cls, values: dict[str, Any]) -> Any:
        """Populate `summary_json` from the raw input, before field validation.

        Only the keys which are annotated on `cls.sj_type` are copied. Note that
        `values` is updated in place.

        Args:
            values (dict[str, Any]): the raw input for the model

        Returns:
            dict[str, Any]: the same mapping, with `summary_json` set
        """
        values["summary_json"] = {
            k: v for k, v in values.items() if k in cls.sj_type.__annotations__
        }

        # Playlists are a unique case, because the SummaryJson and the FullJson both
        # share the key "tracks", which means that if FullJson is passed into this,
        # it needs to be converted down to SummaryJson. PlaylistFullJson is verified
        # by checking for an offset value.
        if cls.sj_type.__name__ == "PlaylistSummaryJson" and "offset" in values[
            "summary_json"
        ].get("tracks", ()):
            values["summary_json"]["tracks"] = {
                "href": values["summary_json"]["tracks"]["href"],
                "total": values["summary_json"]["tracks"]["total"],
            }

        return values

    @classmethod
    def from_json_response(
        cls,
        value: SpotifyEntityJson,
        spotify_client: SpotifyClient,
        additional_fields: dict[str, Any] | None = None,
        metadata: dict[str, object] | None = None,
    ) -> Self:
        """Parse a JSON response from the API into the given entity type model.

        Keys in `value` take precedence over the same keys in `additional_fields`.

        Args:
            value (dict[str, object]): the JSON response from the API
            spotify_client (SpotifyClient): the client to use for future API calls
            additional_fields (dict[str, object] | None): additional fields to add to
                the model
            metadata (dict[str, object] | None): additional metadata to add to the
                model; ignored when empty

        Returns:
            SpotifyEntity: the model for the given entity type
        """

        value_data: dict[str, object] = {
            "spotify_client": spotify_client,
            **(additional_fields or {}),
            **value,
        }

        if metadata:
            value_data["metadata"] = metadata

        return cls.model_validate(value_data)

    @property
    def url(self) -> str:
        """URL of the entity.

        Returns:
            str: the "spotify" entry of `external_urls`, or a URL built from the
                lower-cased class name and the ID if that entry is missing
        """
        return self.external_urls.get(
            "spotify",
            f"https://open.spotify.com/{type(self).__name__.lower()}/{self.id}",
        )

    def __eq__(self, other: object) -> bool:
        """Check if two entities are equal, i.e. they share a URI."""
        if not isinstance(other, SpotifyEntity):
            return NotImplemented
        return self.uri == other.uri

    def __gt__(self, other: object) -> bool:
        """Check if this entity is greater than another."""
        if not isinstance(other, SpotifyEntity):
            return NotImplemented
        return (self.name or self.id).lower() > (other.name or other.id).lower()

    def __hash__(self) -> int:
        """Get the hash of this entity.

        The hash is derived from `uri`, the same value `__eq__` compares, so that
        entities which compare equal always share a hash.
        """
        return hash(self.uri)

    def __lt__(self, other: SpotifyEntity[SJ]) -> bool:
        """Check if this entity is less than another."""
        if not isinstance(other, SpotifyEntity):
            return NotImplemented
        return (self.name or self.id).lower() < (other.name or other.id).lower()

    def __repr__(self) -> str:
        """Get a string representation of this entity."""
        return f'{type(self).__name__}(id="{self.id}", name="{self.name}")'

    def __str__(self) -> str:
        """Get the string representation of this entity."""
        return self.name or f"{type(self).__name__} ({self.id})"

url: str property

URL of the entity.

Returns:

Name Type Description
str str

the URL of this entity

__eq__(other)

Check if two entities are equal.

Source code in wg_utilities/clients/spotify.py
633
634
635
636
637
def __eq__(self, other: object) -> bool:
    """Compare two entities by their URI."""
    if isinstance(other, SpotifyEntity):
        return self.uri == other.uri
    # Let Python try the reflected comparison for unrelated types
    return NotImplemented

__gt__(other)

Check if this entity is greater than another.

Source code in wg_utilities/clients/spotify.py
639
640
641
642
643
def __gt__(self, other: object) -> bool:
    """Order entities by lower-cased name, falling back to ID when unnamed."""
    if isinstance(other, SpotifyEntity):
        own_key = (self.name or self.id).lower()
        other_key = (other.name or other.id).lower()
        return own_key > other_key
    return NotImplemented

__hash__()

Get the hash of this entity.

Source code in wg_utilities/clients/spotify.py
645
646
647
def __hash__(self) -> int:
    """Get the hash of this entity.

    The hash is derived from `uri`, the same value `__eq__` compares, so that
    entities which compare equal always share a hash.
    """
    return hash(self.uri)

__lt__(other)

Check if this entity is less than another.

Source code in wg_utilities/clients/spotify.py
649
650
651
652
653
def __lt__(self, other: SpotifyEntity[SJ]) -> bool:
    """Order entities by lower-cased name, falling back to ID when unnamed."""
    if isinstance(other, SpotifyEntity):
        own_key = (self.name or self.id).lower()
        other_key = (other.name or other.id).lower()
        return own_key < other_key
    return NotImplemented

__repr__()

Get a string representation of this entity.

Source code in wg_utilities/clients/spotify.py
655
656
657
def __repr__(self) -> str:
    """Return the class name along with the entity's ID and name."""
    class_name = type(self).__name__
    return f'{class_name}(id="{self.id}", name="{self.name}")'

__str__()

Get the string representation of this entity.

Source code in wg_utilities/clients/spotify.py
659
660
661
def __str__(self) -> str:
    """Return the entity's name, or its class name and ID if it has no name."""
    if self.name:
        return self.name
    return f"{type(self).__name__} ({self.id})"

from_json_response(value, spotify_client, additional_fields=None, metadata=None) classmethod

Parse a JSON response from the API into the given entity type model.

Parameters:

Name Type Description Default
value dict[str, object]

the JSON response from the API

required
spotify_client SpotifyClient

the client to use for future API calls

required
additional_fields dict[str, object] | None

additional fields to add to the model

None
metadata dict[str, object] | None

additional metadata to add to the model

None

Returns:

Name Type Description
SpotifyEntity Self

the model for the given entity type

Source code in wg_utilities/clients/spotify.py
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
@classmethod
def from_json_response(
    cls,
    value: SpotifyEntityJson,
    spotify_client: SpotifyClient,
    additional_fields: dict[str, Any] | None = None,
    metadata: dict[str, object] | None = None,
) -> Self:
    """Build an entity model from a JSON response returned by the API.

    The model input is assembled from the client, then `additional_fields`, then
    `value`, so keys in `value` win over the same keys in `additional_fields`.

    Args:
        value (dict[str, object]): the JSON response from the API
        spotify_client (SpotifyClient): the client to use for future API calls
        additional_fields (dict[str, object] | None): extra fields to include in
            the model input
        metadata (dict[str, object] | None): metadata to attach to the model;
            ignored when empty

    Returns:
        SpotifyEntity: the validated model for the given entity type
    """

    model_input: dict[str, object] = {"spotify_client": spotify_client}
    model_input.update(additional_fields or {})
    model_input.update(value)

    if metadata:
        model_input["metadata"] = metadata

    return cls.model_validate(model_input)

Track

Bases: SpotifyEntity[TrackFullJson]

A track on Spotify.

Source code in wg_utilities/clients/spotify.py
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
class Track(SpotifyEntity[TrackFullJson]):
    """A track on Spotify.

    The album, artists and audio features are built on first access and then kept
    on the instance in the matching private attribute.
    """

    album_json: AlbumSummaryJson = Field(alias="album")
    artists_json: list[ArtistSummaryJson] = Field(alias="artists")
    audio_features_json: TrackAudioFeaturesJson | None = Field(
        None,
        alias="audio_features",
    )
    available_markets: list[str]
    disc_number: int
    duration_ms: int
    episode: bool | None = None
    explicit: bool
    external_ids: dict[str, str] | None = None
    is_local: bool
    is_playable: bool | None = None
    linked_from: TrackFullJson | None = None
    popularity: int | None = None
    preview_url: str | None = None
    restrictions: str | None = None
    track: bool | None = None
    track_number: int
    type: Literal["track"]

    # Lazily-populated caches for the properties below
    _artists: list[Artist]
    _album: Album
    _audio_features: TrackAudioFeatures | None

    sj_type: ClassVar[SpotifyEntityJsonType] = TrackFullJson

    @property
    def album(self) -> Album:
        """The album this track belongs to.

        Returns:
            Album: the album, built from `album_json` on first access
        """

        if hasattr(self, "_album"):
            return self._album

        self._album = Album.from_json_response(
            self.album_json,
            spotify_client=self.spotify_client,
        )

        return self._album

    @property
    def artist(self) -> Artist:
        """The track's main artist.

        Returns:
            Artist: the first entry of `artists`
        """

        return self.artists[0]

    @property
    def artists(self) -> list[Artist]:
        """The artists who contributed to the track.

        Returns:
            list[Artist]: the artists, built from `artists_json` on first access
        """

        if hasattr(self, "_artists"):
            return self._artists

        self._artists = [
            Artist.from_json_response(
                artist_json,
                spotify_client=self.spotify_client,
            )
            for artist_json in self.artists_json
        ]

        return self._artists

    @property
    def audio_features(self) -> TrackAudioFeatures | None:
        """Audio features of the track.

        Returns:
            TrackAudioFeatures | None: the parsed response from the Spotify
                /audio-features endpoint, or None if the API responds with a 404

        Raises:
            HTTPError: if `get_json_response` throws a HTTPError for a non-200/404
                response
        """
        if hasattr(self, "_audio_features"):
            return self._audio_features

        try:
            features_json = self.spotify_client.get_json_response(
                f"/audio-features/{self.id}",
            )
        except HTTPError as exc:
            # Anything other than a 404 is a genuine failure
            if (
                exc.response is None
                or exc.response.status_code != HTTPStatus.NOT_FOUND
            ):
                raise
            return None

        self._audio_features = TrackAudioFeatures(**features_json)

        return self._audio_features

    @property
    def release_date(self) -> date:
        """Album release date.

        Returns:
            date: the `release_date` of the track's album
        """
        return self.album.release_date

    @property
    def tempo(self) -> float | None:
        """Tempo of the track in BPM.

        Returns:
            float | None: the tempo from the audio features, or None if reading
                it raises an AttributeError (e.g. there are no audio features)
        """
        try:
            features = self.audio_features
            return features.tempo  # type: ignore[union-attr]
        except AttributeError:
            return None

album: Album property

Track's parent album.

Returns:

Name Type Description
Album Album

the album which this track is from

artist: Artist property

Track's parent artist.

Returns:

Name Type Description
Artist Artist

the main artist which this track is from

artists: list[Artist] property

Return a list of artists who contributed to the track.

Returns:

Name Type Description
list Artist

a list of the artists who contributed to this track

audio_features: TrackAudioFeatures | None property

Audio features of the track.

Returns:

Name Type Description
dict TrackAudioFeatures | None

the JSON response from the Spotify /audio-features endpoint

Raises:

Type Description
HTTPError

if get_json_response throws a HTTPError for a non-200/404 response

release_date: date property

Album release date.

Returns:

Name Type Description
date date

the date the track's album was first released

tempo: float | None property

Tempo of the track in BPM.

Returns:

Name Type Description
float float | None

the tempo of the track, in BPM

TrackAudioFeatures

Bases: BaseModelWithConfig

Audio feature information for a single track.

Source code in wg_utilities/clients/spotify.py
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
class TrackAudioFeatures(BaseModelWithConfig):
    """Audio feature information for a single track.

    `Track.audio_features` builds this model from the JSON returned for
    `/audio-features/{id}`. No field declares a default, so all are required.
    """

    acousticness: float
    analysis_url: str
    danceability: float
    duration_ms: int
    energy: float
    id: str
    instrumentalness: float
    key: int
    liveness: float
    loudness: float
    mode: int
    speechiness: float
    tempo: float  # exposed as `Track.tempo`, which documents it as BPM
    time_signature: int
    track_href: str
    # Only the literal value "audio_features" is accepted
    type: Literal["audio_features"]
    uri: str
    valence: float

User

Bases: SpotifyEntity[UserSummaryJson]

A Spotify user, usually just the current user.

Source code in wg_utilities/clients/spotify.py
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
class User(SpotifyEntity[UserSummaryJson]):
    """A Spotify user, usually just the current user."""

    PLAYLIST_REFRESH_INTERVAL: ClassVar[timedelta] = timedelta(minutes=10)

    display_name: str
    country: str | None = None
    email: str | None = None
    explicit_content: dict[str, bool] | None = None
    followers: Followers | None = None
    images: list[Image] | None = None
    product: str | None = None
    type: Literal["user"]

    _albums: list[Album]
    _artists: list[Artist]
    _playlists: list[Playlist]
    _top_artists: tuple[Artist, ...]
    _top_tracks: tuple[Track, ...]
    _tracks: list[Track]

    _playlist_refresh_time: datetime

    sj_type: ClassVar[SpotifyEntityJsonType] = UserSummaryJson

    @field_validator("display_name", mode="before")
    @classmethod
    def set_user_name_value(cls, value: str, info: ValidationInfo) -> str:
        """Copy the display name into `name` when no name has been provided.

        Args:
            value (str): the incoming display name
            info (ValidationInfo): validation context; its `data` mapping gains a
                `name` entry when the existing one is missing or falsy

        Returns:
            str: the display name, unchanged
        """

        existing_name = info.data.get("name")
        if existing_name:
            return value

        info.data["name"] = value
        return value

    @overload
    def get_playlists_by_name(
        self,
        name: str,
        *,
        return_all: Literal[False] = False,
    ) -> Playlist | None: ...

    @overload
    def get_playlists_by_name(
        self,
        name: str,
        *,
        return_all: Literal[True],
    ) -> list[Playlist]: ...

    def get_playlists_by_name(
        self,
        name: str,
        *,
        return_all: bool = False,
    ) -> list[Playlist] | Playlist | None:
        """Find the user's playlist(s) with the given name, ignoring case.

        Args:
            name (str): the name of the target playlist(s)
            return_all (bool): playlist names aren't guaranteed to be unique. When
                True, every matching playlist is returned as a sorted list; when
                False, only the first match is returned

        Returns:
            list[Playlist] | Playlist | None: all matches (sorted) if `return_all`
                is set, otherwise the first match or None if nothing matched
        """

        matches = (
            playlist
            for playlist in self.playlists
            if playlist.name.lower() == name.lower()
        )

        if not return_all:
            # First match only; None when there are no matches at all
            return next(matches, None)

        return sorted(matches)

    def get_recently_liked_tracks(
        self,
        track_limit: int = 100,
        *,
        day_limit: float = 0.0,
    ) -> list[Track]:
        """Get the tracks most recently liked by the current user.

        Args:
            track_limit (int): passed to `get_items` as the hard limit on the
                number of items
            day_limit (float): the number of days (N) to go back in time for. When
                falsy, no time-based limit function is passed to `get_items`

        Returns:
            list: a list of Track instances, each with a `saved_at` datetime in
                its metadata
        """

        def _saved_at(item: dict[str, Any]) -> datetime:
            # Parse the item's `added_at` timestamp as a UTC datetime
            return datetime.strptime(
                item["added_at"],
                self.spotify_client.DATETIME_FORMAT,
            ).replace(tzinfo=UTC)

        limit_func: (
            Callable[
                [SpotifyEntityJson | dict[str, Any]],
                bool,
            ]
            | None
        ) = None

        if day_limit:

            def _saved_before_cutoff(item: dict[str, Any]) -> bool:
                # True once an item was saved more than `day_limit` days ago
                return bool(
                    _saved_at(item) < (datetime.now(UTC) - timedelta(days=day_limit)),
                )

            limit_func = _saved_before_cutoff  # type: ignore[assignment]

        saved_items = cast(
            list[SavedItem],
            self.spotify_client.get_items(
                "/me/tracks",
                hard_limit=track_limit,
                limit_func=limit_func,
            ),
        )

        return [
            Track.from_json_response(
                item["track"],
                spotify_client=self.spotify_client,
                metadata={"saved_at": _saved_at(item)},  # type: ignore[arg-type]
            )
            for item in saved_items
        ]

    def save(self, entity: Album | Artist | Playlist | Track) -> None:
        """Save an entity to the user's library.

        Args:
            entity (Album|Artist|Playlist|Track): the entity to save

        Raises:
            TypeError: if the entity is not of a supported type
        """

        if isinstance(entity, Album):
            url = f"{self.spotify_client.BASE_URL}/me/albums"
            params = {"ids": entity.id}
        elif isinstance(entity, Artist):
            url = f"{self.spotify_client.BASE_URL}/me/following"
            params = {"type": "artist", "ids": entity.id}
        elif isinstance(entity, Playlist):
            url = f"{self.spotify_client.BASE_URL}/playlists/{entity.id}/followers"
            params = {"ids": self.id}
        elif isinstance(entity, Track):
            url = f"{self.spotify_client.BASE_URL}/me/tracks"
            params = {"ids": entity.id}
        else:
            raise TypeError(
                f"Cannot save entity of type `{type(entity).__name__}`. "
                f"Must be one of: Album, Artist, Playlist, Track",
            )

        res = put(
            url,
            params=params,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.spotify_client.access_token}",
                "Host": "api.spotify.com",
            },
            timeout=10,
        )
        res.raise_for_status()

    def unsave(self, entity: Album | Artist | Playlist | Track) -> None:
        """Remove an entity from the user's library.

        Args:
            entity (Album|Artist|Playlist|Track): the entity to remove

        Raises:
            TypeError: if the entity is not of a supported type
        """

        if isinstance(entity, Album):
            url = f"{self.spotify_client.BASE_URL}/me/albums"
            params = {"ids": entity.id}
        elif isinstance(entity, Artist):
            url = f"{self.spotify_client.BASE_URL}/me/following"
            params = {"type": "artist", "ids": entity.id}
        elif isinstance(entity, Playlist):
            url = f"{self.spotify_client.BASE_URL}/playlists/{entity.id}/followers"
            params = {"ids": self.id}
        elif isinstance(entity, Track):
            url = f"{self.spotify_client.BASE_URL}/me/tracks"
            params = {"ids": entity.id}
        else:
            raise TypeError(
                f"Cannot unsave entity of type `{type(entity).__name__}`. "
                f"Must be one of: Album, Artist, Playlist, Track",
            )

        res = delete(
            url,
            params=params,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {self.spotify_client.access_token}",
                "Host": "api.spotify.com",
            },
            timeout=10,
        )
        if res.status_code != HTTPStatus.BAD_REQUEST:
            res.raise_for_status()

    @property
    def albums(self) -> list[Album]:
        """Albums saved in the user's library.

        The list is fetched from `/me/albums` on first access and then reused.

        Returns:
            list: a list of albums owned by the current user
        """

        if hasattr(self, "_albums"):
            return self._albums

        saved_albums = cast(
            list[SavedItem],
            self.spotify_client.get_items("/me/albums"),
        )

        self._albums = [
            Album.from_json_response(
                saved_album["album"],
                spotify_client=self.spotify_client,
            )
            for saved_album in saved_albums
        ]

        return self._albums

    @property
    def artists(self) -> list[Artist]:
        """Artists followed by the user.

        The list is fetched from `/me/following` on first access and then reused.

        Returns:
            list: a list of artists owned by the current user
        """

        if hasattr(self, "_artists"):
            return self._artists

        followed_artists = self.spotify_client.get_items(
            "/me/following",
            params={"type": "artist"},
            top_level_key="artists",
        )

        self._artists = [
            Artist.from_json_response(
                artist_json,
                spotify_client=self.spotify_client,
            )
            for artist_json in followed_artists
        ]

        return self._artists

    @property
    def current_track(self) -> Track | None:
        """The track the user is currently playing.

        Returns:
            Track | None: the track currently being listened to, or None if the
                response has no (truthy) `item`
        """

        now_playing = cast(
            SavedItem,
            self.spotify_client.get_json_response("/me/player/currently-playing"),
        )

        track_json = now_playing.get("item")
        if not track_json:
            return None

        return Track.from_json_response(
            track_json,
            spotify_client=self.spotify_client,
        )

    @property
    def current_playlist(self) -> Playlist | None:
        """Get the current playlist for the given user.

        Returns:
            Playlist | None: the playlist currently being listened to, or None if
                the playback context is missing, null, or not a playlist
        """

        res = self.spotify_client.get_json_response("/me/player/currently-playing")

        # `context` may be present but null, in which case `.get("context", {})`
        # would hand back None rather than the default
        context = res.get("context") or {}  # type: ignore[attr-defined]

        if context.get("type") != "playlist":
            return None

        # The playlist ID is the final segment of the context URI
        playlist: Playlist = self.spotify_client.get_playlist_by_id(
            context["uri"].split(":")[-1],
        )
        return playlist

    @property
    def devices(self) -> list[Device]:
        """Devices which the user currently has access to.

        Returns:
            list[Device]: the items from `/me/player/devices`, each validated as a
                Device
        """
        device_items = self.spotify_client.get_items(
            "/me/player/devices",
            list_key="devices",
        )

        return list(map(Device.model_validate, device_items))

    @property
    def playlists(self) -> list[Playlist]:
        """Return a list of playlists owned by the current user.

        If self.PLAYLIST_REFRESH_INTERVAL has elapsed, a new call to the API will be
        made to refresh the list of playlists. Only new playlists will be added to the
        list, preserving previous instances.

        Returns:
            list: a list of playlists owned by the current user
        """

        if (
            hasattr(self, "_playlist_refresh_time")
            and (datetime.now(UTC) - self._playlist_refresh_time)
            < self.PLAYLIST_REFRESH_INTERVAL
        ):
            return self._playlists

        all_playlist_json = cast(
            list[PlaylistSummaryJson],
            self.spotify_client.get_items("/me/playlists"),
        )

        if not hasattr(self, "_playlists"):
            playlists = [
                Playlist.from_json_response(item, spotify_client=self.spotify_client)
                for item in all_playlist_json
                if item["owner"]["id"] == self.id
            ]

            self._playlists = playlists
        else:
            # A set, not a generator: a generator would be consumed by the first
            # membership tests, so later playlists would wrongly look "new"
            existing_ids = {p.id for p in self._playlists}
            new_playlists = [
                Playlist.from_json_response(item, spotify_client=self.spotify_client)
                for item in all_playlist_json
                if item["owner"]["id"] == self.id and item["id"] not in existing_ids
            ]

            self._playlists.extend(new_playlists)

        # Only record the refresh once the list has been updated, so that a failed
        # request can't leave a refresh time set without `_playlists` existing
        self._playlist_refresh_time = datetime.now(UTC)

        return self._playlists

    @property
    def top_artists(self) -> tuple[Artist, ...]:
        """The user's top artists over the `short_term` time range.

        The result is fetched once and then kept on the instance as `_top_artists`.

        Returns:
            tuple[Artist, ...]: the top artists for the user
        """

        if hasattr(self, "_top_artists"):
            return self._top_artists

        artist_payloads = self.spotify_client.get_items(
            "/me/top/artists",
            params={"time_range": "short_term"},
        )

        self._top_artists = tuple(
            Artist.from_json_response(payload, spotify_client=self.spotify_client)
            for payload in artist_payloads
        )

        return self._top_artists

    @property
    def top_tracks(self) -> tuple[Track, ...]:
        """The user's top tracks over the `short_term` time range.

        The result is fetched once and then kept on the instance as `_top_tracks`.

        Returns:
            tuple[Track, ...]: the top tracks for the user
        """
        if hasattr(self, "_top_tracks"):
            return self._top_tracks

        track_payloads = self.spotify_client.get_items(
            "/me/top/tracks",
            params={"time_range": "short_term"},
        )

        self._top_tracks = tuple(
            Track.from_json_response(payload, spotify_client=self.spotify_client)
            for payload in track_payloads
        )

        return self._top_tracks

    @property
    def tracks(self) -> list[Track]:
        """Liked Songs.

        Fetched once from `/me/tracks` and then kept on the instance as `_tracks`.
        Each track carries a `saved_at` metadata entry parsed from the item's
        `added_at` field and marked as UTC.

        Returns:
            list: a list of tracks owned by the current user
        """

        if hasattr(self, "_tracks"):
            return self._tracks

        saved_items = cast(
            list[SavedItem],
            self.spotify_client.get_items("/me/tracks"),
        )

        liked_tracks = []
        for saved in saved_items:
            track_json = saved["track"]
            saved_at = datetime.strptime(
                saved["added_at"],
                self.spotify_client.DATETIME_FORMAT,
            ).replace(tzinfo=UTC)

            liked_tracks.append(
                Track.from_json_response(
                    track_json,
                    spotify_client=self.spotify_client,
                    metadata={"saved_at": saved_at},
                ),
            )

        self._tracks = liked_tracks

        return self._tracks

    def reset_properties(
        self,
        property_names: (
            Iterable[
                Literal[
                    "albums",
                    "artists",
                    "playlists",
                    "top_artists",
                    "top_tracks",
                    "tracks",
                ]
            ]
            | None
        ) = None,
    ) -> None:
        """Reset cached list properties so they are rebuilt on next access.

        Args:
            property_names (Iterable[str] | None): names of the properties to reset.
                When omitted (or falsy), every supported property is reset.
        """

        if not property_names:
            names: tuple[str, ...] = (
                "albums",
                "artists",
                "playlists",
                "top_artists",
                "top_tracks",
                "tracks",
            )
        else:
            # Materialise once: the argument may be a one-shot iterator, and it is
            # consulted a second time for the playlist refresh timestamp below.
            names = tuple(property_names)

        for prop_name in names:
            if hasattr(self, attr_name := f"_{prop_name}"):
                delattr(self, attr_name)

        # The timestamp only exists once playlists have been loaded, so guard the
        # delete just like the cached values above.
        if "playlists" in names and hasattr(self, "_playlist_refresh_time"):
            delattr(self, "_playlist_refresh_time")

albums: list[Album] property

List of albums in the user's library.

Returns:

Name Type Description
list list[Album]

a list of albums owned by the current user

artists: list[Artist] property

List of artists in the user's library.

Returns:

Name Type Description
list list[Artist]

a list of artists owned by the current user

current_playlist: Playlist | None property

Get the current playlist for the given user.

Returns:

Name Type Description
Playlist Playlist | None

the playlist currently being listened to

current_track: Track | None property

Get the currently playing track for the given user.

Returns:

Name Type Description
Track Track | None

the track currently being listened to

devices: list[Device] property

Return a list of devices that the user currently has access to.

Returns:

Type Description
list[Device]

list[Device]: a list of devices available to the user

playlists: list[Playlist] property

Return a list of playlists owned by the current user.

If self.PLAYLIST_REFRESH_INTERVAL has elapsed, a new call to the API will be made to refresh the list of playlists. Only new playlists will be added to the list, preserving previous instances.

Returns:

Name Type Description
list list[Playlist]

a list of playlists owned by the current user

top_artists: tuple[Artist, ...] property

Top artists for the user.

Returns:

Type Description
tuple[Artist, ...]

tuple[Artist, ...]: the top artists for the user

top_tracks: tuple[Track, ...] property

The top tracks for the user.

Returns:

Type Description
tuple[Track, ...]

tuple[Track, ...]: the top tracks for the user

tracks: list[Track] property

Liked Songs.

Returns:

Name Type Description
list list[Track]

a list of tracks owned by the current user

get_playlists_by_name(name, *, return_all=False)

Get Playlist instance(s) which have the given name.

Parameters:

Name Type Description Default
name str

the name of the target playlist(s)

required
return_all bool

playlist names aren't unique - but most people keep them unique within their own Sequence of playlists. This boolean can be used to return either a list of all matching playlists, or just the single found playlist

False

Returns:

Name Type Description
list[Playlist] | Playlist | None

the matched playlist(s)

Source code in wg_utilities/clients/spotify.py
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
def get_playlists_by_name(
    self,
    name: str,
    *,
    return_all: bool = False,
) -> list[Playlist] | Playlist | None:
    """Look up the user's playlist(s) by name, ignoring case.

    Args:
        name (str): the name of the target playlist(s)
        return_all (bool): playlist names aren't guaranteed to be unique. When
            True, every matching playlist is returned as a sorted list; when
            False, only the first match is returned.

    Returns:
        list[Playlist] | Playlist | None: all matches (`return_all=True`), the
            first match, or None when nothing matches
    """

    # Lazy on purpose: the single-result path stops at the first hit.
    matches = (
        playlist
        for playlist in self.playlists
        if playlist.name.lower() == name.lower()
    )

    if return_all:
        return sorted(matches)

    return next(matches, None)

get_recently_liked_tracks(track_limit=100, *, day_limit=0.0)

Get a list of songs which were liked by the current user in the past N days.

Parameters:

Name Type Description Default
track_limit int

the number of tracks to return

100
day_limit float

the number of days (N) to go back in time for

0.0

Returns:

Name Type Description
list list[Track]

a list of Track instances

Source code in wg_utilities/clients/spotify.py
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
def get_recently_liked_tracks(
    self,
    track_limit: int = 100,
    *,
    day_limit: float = 0.0,
) -> list[Track]:
    """Get the tracks most recently liked by the current user.

    Args:
        track_limit (int): handed to `get_items` as `hard_limit`
        day_limit (float): age cut-off in days (N); when falsy no `limit_func`
            is supplied to `get_items`

    Returns:
        list: Track instances, each with a `saved_at` metadata entry
    """

    if not day_limit:
        limit_func: (
            Callable[
                [SpotifyEntityJson | dict[str, Any]],
                bool,
            ]
            | None
        ) = None

    else:

        def limit_func(item: dict[str, Any]) -> bool:  # type: ignore[misc]
            # True when the item was liked before the cut-off; the cut-off is
            # recomputed from the current time on every call.
            liked_at = datetime.strptime(
                item["added_at"],
                self.spotify_client.DATETIME_FORMAT,
            ).replace(tzinfo=UTC)
            cutoff = datetime.now(UTC) - timedelta(days=day_limit)
            return bool(liked_at < cutoff)

    saved_items = cast(
        list[SavedItem],
        self.spotify_client.get_items(
            "/me/tracks",
            hard_limit=track_limit,
            limit_func=limit_func,
        ),
    )

    recent_tracks = []
    for saved in saved_items:
        track_json = saved["track"]
        saved_at = datetime.strptime(
            saved["added_at"],
            self.spotify_client.DATETIME_FORMAT,
        ).replace(tzinfo=UTC)

        recent_tracks.append(
            Track.from_json_response(
                track_json,
                spotify_client=self.spotify_client,
                metadata={"saved_at": saved_at},
            ),
        )

    return recent_tracks

reset_properties(property_names=None)

Reset all list properties.

Source code in wg_utilities/clients/spotify.py
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
def reset_properties(
    self,
    property_names: (
        Iterable[
            Literal[
                "albums",
                "artists",
                "playlists",
                "top_artists",
                "top_tracks",
                "tracks",
            ]
        ]
        | None
    ) = None,
) -> None:
    """Reset cached list properties so they are rebuilt on next access.

    Args:
        property_names (Iterable[str] | None): names of the properties to reset.
            When omitted (or falsy), every supported property is reset.
    """

    if not property_names:
        names: tuple[str, ...] = (
            "albums",
            "artists",
            "playlists",
            "top_artists",
            "top_tracks",
            "tracks",
        )
    else:
        # Materialise once: the argument may be a one-shot iterator, and it is
        # consulted a second time for the playlist refresh timestamp below.
        names = tuple(property_names)

    for prop_name in names:
        if hasattr(self, attr_name := f"_{prop_name}"):
            delattr(self, attr_name)

    # The timestamp only exists once playlists have been loaded, so guard the
    # delete just like the cached values above.
    if "playlists" in names and hasattr(self, "_playlist_refresh_time"):
        delattr(self, "_playlist_refresh_time")

save(entity)

Save an entity to the user's library.

Parameters:

Name Type Description Default
entity Album | Artist | Playlist | Track

the entity to save

required

Raises:

Type Description
TypeError

if the entity is not of a supported type

Source code in wg_utilities/clients/spotify.py
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
def save(self, entity: Album | Artist | Playlist | Track) -> None:
    """Add an entity to the user's library via a PUT request.

    Args:
        entity (Album|Artist|Playlist|Track): the entity to save

    Raises:
        TypeError: if the entity is not of a supported type
    """

    # Reject unsupported types up front, before building any request.
    if not isinstance(entity, (Album, Artist, Playlist, Track)):
        raise TypeError(
            f"Cannot save entity of type `{type(entity).__name__}`. "
            f"Must be one of: Album, Artist, Playlist, Track",
        )

    if isinstance(entity, Album):
        url = f"{self.spotify_client.BASE_URL}/me/albums"
        params = {"ids": entity.id}
    elif isinstance(entity, Artist):
        url = f"{self.spotify_client.BASE_URL}/me/following"
        params = {"type": "artist", "ids": entity.id}
    elif isinstance(entity, Playlist):
        # Playlists are followed: the request carries the user's own ID.
        url = f"{self.spotify_client.BASE_URL}/playlists/{entity.id}/followers"
        params = {"ids": self.id}
    else:
        # Only Track remains after the guard above.
        url = f"{self.spotify_client.BASE_URL}/me/tracks"
        params = {"ids": entity.id}

    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {self.spotify_client.access_token}",
        "Host": "api.spotify.com",
    }

    response = put(url, params=params, headers=request_headers, timeout=10)
    response.raise_for_status()

set_user_name_value(value, info) classmethod

Set the user's name field to the display name if it is not set.

Parameters:

Name Type Description Default
value str

the display name

required
info ValidationInfo

Object for extra validation information/data.

required

Returns:

Name Type Description
str str

the display name

Source code in wg_utilities/clients/spotify.py
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
@field_validator("display_name", mode="before")
@classmethod
def set_user_name_value(cls, value: str, info: ValidationInfo) -> str:
    """Copy the display name into `name` when `name` is missing or empty.

    Args:
        value (str): the display name
        info (ValidationInfo): Object for extra validation information/data.

    Returns:
        str: the display name, unchanged
    """

    validated_so_far = info.data

    if validated_so_far.get("name"):
        return value

    validated_so_far["name"] = value
    return value

unsave(entity)

Remove an entity from the user's library.

Parameters:

Name Type Description Default
entity Album | Artist | Playlist | Track

the entity to remove

required

Raises:

Type Description
TypeError

if the entity is not of a supported type

Source code in wg_utilities/clients/spotify.py
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
def unsave(self, entity: Album | Artist | Playlist | Track) -> None:
    """Remove an entity from the user's library via a DELETE request.

    Args:
        entity (Album|Artist|Playlist|Track): the entity to remove

    Raises:
        TypeError: if the entity is not of a supported type
    """

    # Reject unsupported types up front, before building any request.
    if not isinstance(entity, (Album, Artist, Playlist, Track)):
        raise TypeError(
            f"Cannot unsave entity of type `{type(entity).__name__}`. "
            f"Must be one of: Album, Artist, Playlist, Track",
        )

    if isinstance(entity, Album):
        url = f"{self.spotify_client.BASE_URL}/me/albums"
        params = {"ids": entity.id}
    elif isinstance(entity, Artist):
        url = f"{self.spotify_client.BASE_URL}/me/following"
        params = {"type": "artist", "ids": entity.id}
    elif isinstance(entity, Playlist):
        # Playlists are unfollowed: the request carries the user's own ID.
        url = f"{self.spotify_client.BASE_URL}/playlists/{entity.id}/followers"
        params = {"ids": self.id}
    else:
        # Only Track remains after the guard above.
        url = f"{self.spotify_client.BASE_URL}/me/tracks"
        params = {"ids": entity.id}

    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {self.spotify_client.access_token}",
        "Host": "api.spotify.com",
    }

    response = delete(url, params=params, headers=request_headers, timeout=10)

    # A 400 response is deliberately not raised; any other error status is.
    if response.status_code == HTTPStatus.BAD_REQUEST:
        return

    response.raise_for_status()

truelayer

Custom client for interacting with TrueLayer's API.

Account

Bases: TrueLayerEntity

Class for managing individual bank accounts.

Source code in wg_utilities/clients/truelayer.py
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
class Account(TrueLayerEntity):
    """Class for managing individual bank accounts."""

    # Names of the balance-related values relevant to an account.
    BALANCE_FIELDS: ClassVar[Iterable[str]] = (
        "available_balance",
        "current_balance",
        "overdraft",
    )
    account_number: _AccountNumber
    account_type: AccountType

    _overdraft: float

    @field_validator("account_type", mode="before")
    @classmethod
    def validate_account_type(cls, value: str) -> AccountType:
        """Validate `account_type` and parse it into an `AccountType` instance.

        The lookup is by member name and is case-insensitive.

        Args:
            value (str): the raw account type, or an existing `AccountType`

        Returns:
            AccountType: the matching enum member

        Raises:
            ValueError: if the value does not name an `AccountType` member
        """
        if isinstance(value, AccountType):
            return value

        # Normalise before the membership test so the check and the lookup agree;
        # non-string input is left as-is and rejected below.
        member_name = value.upper() if isinstance(value, str) else value

        if member_name not in AccountType.__members__:  # pragma: no cover
            raise ValueError(f"Invalid account type: `{value}`")

        return AccountType[member_name]

    @property
    def overdraft(self) -> float | None:
        """Overdraft limit for the account.

        Returns:
            float | None: whatever `_get_balance_property("overdraft")` yields
        """
        return self._get_balance_property("overdraft")

overdraft: float | None property

Overdraft limit for the account.

Returns:

Name Type Description
float float | None

the overdraft limit of the account

validate_account_type(value) classmethod

Validate account_type and parse it into an AccountType instance.

Source code in wg_utilities/clients/truelayer.py
465
466
467
468
469
470
471
472
473
474
475
@field_validator("account_type", mode="before")
@classmethod
def validate_account_type(cls, value: str) -> AccountType:
    """Validate `account_type` and parse it into an `AccountType` instance.

    The lookup is by member name and is case-insensitive.

    Args:
        value (str): the raw account type, or an existing `AccountType`

    Returns:
        AccountType: the matching enum member

    Raises:
        ValueError: if the value does not name an `AccountType` member
    """
    if isinstance(value, AccountType):
        return value

    # Normalise before the membership test so the check and the lookup agree;
    # non-string input is left as-is and rejected below.
    member_name = value.upper() if isinstance(value, str) else value

    if member_name not in AccountType.__members__:  # pragma: no cover
        raise ValueError(f"Invalid account type: `{value}`")

    return AccountType[member_name]

AccountJson

Bases: _TrueLayerBaseEntityJson

JSON representation of a TrueLayer Account.

Source code in wg_utilities/clients/truelayer.py
153
154
155
156
157
class AccountJson(_TrueLayerBaseEntityJson):
    """JSON representation of a TrueLayer Account."""

    # Account-specific keys, declared on top of whatever the base JSON type defines.
    account_number: _AccountNumber
    account_type: AccountType

AccountType

Bases: StrEnum

Possible TrueLayer account types.

Source code in wg_utilities/clients/truelayer.py
35
36
37
38
39
40
41
class AccountType(StrEnum):
    """Possible TrueLayer account types.

    Members use `auto()`, which for a `StrEnum` yields the lower-cased member name
    as the value (e.g. `AccountType.SAVINGS == "savings"`).
    """

    TRANSACTION = auto()
    SAVINGS = auto()
    BUSINESS_TRANSACTION = auto()
    BUSINESS_SAVINGS = auto()

BalanceVariables

Bases: BaseModelWithConfig

Variables for an account's balance summary.

Source code in wg_utilities/clients/truelayer.py
171
172
173
174
175
176
177
178
179
180
181
class BalanceVariables(BaseModelWithConfig):
    """Variables for an account's balance summary.

    All fields are required (none declares a default).
    """

    # NOTE(review): monetary fields are typed `int` here — confirm the expected
    # unit/precision against the API payload.
    available_balance: int
    current_balance: int
    overdraft: int
    credit_limit: int
    last_statement_balance: int
    last_statement_date: date
    payment_due: int
    payment_due_date: date

Bank

Bases: StrEnum

Enum for all banks supported by TrueLayer.

Source code in wg_utilities/clients/truelayer.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
class Bank(StrEnum):
    """Enum for all banks supported by TrueLayer.

    Each member's value is the bank's human-readable display name.
    """

    ALLIED_IRISH_BANK_CORPORATE = "Allied Irish Bank Corporate"
    AMEX = "Amex"
    BANK_OF_SCOTLAND = "Bank of Scotland"
    BANK_OF_SCOTLAND_BUSINESS = "Bank of Scotland Business"
    BARCLAYCARD = "Barclaycard"
    BARCLAYS = "Barclays"
    BARCLAYS_BUSINESS = "Barclays Business"
    CAPITAL_ONE = "Capital One"
    CHELSEA_BUILDING_SOCIETY = "Chelsea Building Society"
    DANSKE_BANK = "Danske Bank"
    DANSKE_BANK_BUSINESS = "Danske Bank Business"
    FIRST_DIRECT = "First Direct"
    HALIFAX = "Halifax"
    HSBC = "HSBC"
    HSBC_BUSINESS = "HSBC Business"
    LLOYDS = "Lloyds"
    LLOYDS_BUSINESS = "Lloyds Business"
    LLOYDS_COMMERCIAL = "Lloyds Commercial"
    M_S_BANK = "M&S Bank"
    MBNA = "MBNA"
    MONZO = "Monzo"
    NATIONWIDE = "Nationwide"
    NATWEST = "NatWest"
    NATWEST_BUSINESS = "NatWest Business"
    REVOLUT = "Revolut"
    ROYAL_BANK_OF_SCOTLAND = "Royal Bank of Scotland"
    ROYAL_BANK_OF_SCOTLAND_BUSINESS = "Royal Bank of Scotland Business"
    SANTANDER = "Santander"
    STARLING = "Starling"
    STARLING_JOINT = "Starling Joint"
    TESCO_BANK = "Tesco Bank"
    TIDE = "Tide"
    TSB = "TSB"
    ULSTER_BANK = "Ulster Bank"
    ULSTER_BUSINESS = "Ulster Business"
    VIRGIN_MONEY = "Virgin Money"
    WISE = "Wise"
    YORKSHIRE_BUILDING_SOCIETY = "Yorkshire Building Society"

Card

Bases: TrueLayerEntity

Class for managing individual cards.

Source code in wg_utilities/clients/truelayer.py
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
class Card(TrueLayerEntity):
    """Class for managing individual cards."""

    # Names of the balance-related values relevant to a card. The last five have
    # a matching property below; `available_balance` and `current_balance` are
    # not defined in this class (presumably inherited — TODO confirm).
    BALANCE_FIELDS: ClassVar[Iterable[str]] = (
        "available_balance",
        "current_balance",
        "credit_limit",
        "last_statement_balance",
        "last_statement_date",
        "payment_due",
        "payment_due_date",
    )

    card_network: str
    card_type: str
    partial_card_number: str
    name_on_card: str
    valid_from: date | None = None
    valid_to: date | None = None

    # Private attributes named after the properties below; presumably used as
    # their backing storage by `_get_balance_property` — TODO confirm.
    _credit_limit: float
    _last_statement_balance: float
    _last_statement_date: date
    _payment_due: float
    _payment_due_date: date

    @property
    def credit_limit(self) -> float | None:
        """Credit limit of the account.

        Returns:
            float | None: the credit limit available to the customer, as
                provided by `_get_balance_property`
        """
        return self._get_balance_property("credit_limit")

    @property
    def last_statement_balance(self) -> float | None:
        """Balance of the account at the last statement date.

        Returns:
            float | None: the balance on the last statement, as provided by
                `_get_balance_property`
        """
        return self._get_balance_property("last_statement_balance")

    @property
    def last_statement_date(self) -> date | None:
        """Date of the last statement.

        Returns:
            date | None: the date the last statement was issued on, as provided
                by `_get_balance_property`
        """
        return self._get_balance_property("last_statement_date")

    @property
    def payment_due(self) -> float | None:
        """Amount due on the next statement.

        Returns:
            float | None: the amount of any due payment, as provided by
                `_get_balance_property`
        """
        return self._get_balance_property("payment_due")

    @property
    def payment_due_date(self) -> date | None:
        """Date on which the next payment is due.

        Returns:
            date | None: the date on which the next payment is due, as provided
                by `_get_balance_property`
        """
        return self._get_balance_property("payment_due_date")

credit_limit: float | None property

Credit limit of the account.

Returns:

Name Type Description
float float | None

the credit limit available to the customer

last_statement_balance: float | None property

Balance of the account at the last statement date.

Returns:

Name Type Description
float float | None

the balance on the last statement

last_statement_date: date | None property

Date of the last statement.

Returns:

Name Type Description
date date | None

the date the last statement was issued on

payment_due: float | None property

Amount due on the next statement.

Returns:

Name Type Description
float float | None

the amount of any due payment

payment_due_date: date | None property

Date of the next statement.

Returns:

Name Type Description
date date | None

the date on which the next payment is due

CardJson

Bases: _TrueLayerBaseEntityJson

JSON representation of a Card.

Source code in wg_utilities/clients/truelayer.py
160
161
162
163
164
165
166
167
168
class CardJson(_TrueLayerBaseEntityJson):
    """JSON representation of a Card."""

    # Card-specific keys, declared on top of whatever the base JSON type defines.
    card_network: str
    card_type: str
    partial_card_number: str
    name_on_card: str
    # The validity dates are optional keys.
    valid_from: NotRequired[date]
    valid_to: NotRequired[date]

Transaction

Bases: BaseModelWithConfig

Class for individual transactions for data manipulation etc.

Source code in wg_utilities/clients/truelayer.py
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
class Transaction(BaseModelWithConfig):
    """A single transaction, modelled for data manipulation etc."""

    amount: float
    currency: str
    description: str
    id: str = Field(alias="transaction_id")
    merchant_name: str | None = None
    meta: dict[str, str]
    normalised_provider_transaction_id: str | None = None
    provider_transaction_id: str | None
    running_balance: dict[str, str | float] | None = None
    timestamp: datetime
    transaction_category: TransactionCategory
    transaction_classification: list[str]
    transaction_type: str

    @field_validator("transaction_category", mode="before")
    @classmethod
    def validate_transaction_category(cls, v: str) -> TransactionCategory:
        """Resolve a raw category name into its `TransactionCategory` member.

        The lookup is done here by member name rather than relying on the default
        Enum handling, which also lets the error message include the offending
        value.

        Raises:
            ValueError: if `v` is not the name of a `TransactionCategory` member
        """
        if v in TransactionCategory.__members__:
            return TransactionCategory[v]

        raise ValueError(f"Invalid transaction category: {v}")  # pragma: no cover

    def __str__(self) -> str:
        """Render the transaction as `description | amount | merchant_name`."""
        parts = (self.description, self.amount, self.merchant_name)
        return " | ".join(format(part) for part in parts)

__str__()

Return a string representation of the transaction.

Source code in wg_utilities/clients/truelayer.py
447
448
449
def __str__(self) -> str:
    """Render the transaction as `description | amount | merchant_name`."""
    parts = (self.description, self.amount, self.merchant_name)
    return " | ".join(format(part) for part in parts)

validate_transaction_category(v) classmethod

Validate the transaction category.

The default Enum assignment doesn't work for some reason, so we have to do it here.

This also helps to provide a meaningful error message if the category is invalid; Pydantic's doesn't include the invalid value unfortunately.

Source code in wg_utilities/clients/truelayer.py
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
@field_validator("transaction_category", mode="before")
@classmethod
def validate_transaction_category(cls, v: str) -> TransactionCategory:
    """Resolve a raw category name into its `TransactionCategory` member.

    The lookup is done here by member name rather than relying on the default
    Enum handling, which also lets the error message include the offending value.

    Raises:
        ValueError: if `v` is not the name of a `TransactionCategory` member
    """
    if v in TransactionCategory.__members__:
        return TransactionCategory[v]

    raise ValueError(f"Invalid transaction category: {v}")  # pragma: no cover

TransactionCategory

Bases: Enum

Enum for TrueLayer transaction types.

The `__init__` method is overridden to allow setting a description as well as the main value.

Source code in wg_utilities/clients/truelayer.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
class TransactionCategory(Enum):
    """Enum for TrueLayer transaction types.

    Each member is declared as a `(value, description)` pair. `__new__` is
    overridden so that the first element becomes the member's value and the second
    is stored as `description`. Doing this in `__new__` (rather than `__init__`)
    means the member is registered under its string value, so lookup by value
    (e.g. `TransactionCategory("ATM")`) works.
    """

    ATM = (
        "ATM",
        "Deposit or withdrawal of funds using an ATM (Automated Teller Machine)",
    )
    BILL_PAYMENT = "Bill Payment", "Payment of a bill"
    CASH = (
        "Cash",
        "Cash deposited over the branch counter or using Cash and Deposit Machines",
    )
    CASHBACK = (
        "Cashback",
        "An option retailers offer to withdraw cash while making a debit card purchase",
    )
    CHEQUE = (
        "Cheque",
        "A document ordering the payment of money from a bank account to another person"
        " or organisation",
    )
    CORRECTION = "Correction", "Correction of a transaction error"
    CREDIT = "Credit", "Funds added to your account"
    DIRECT_DEBIT = (
        "Direct Debit",
        "An automatic withdrawal of funds initiated by a third party at regular"
        " intervals",
    )
    DIVIDEND = "Dividend", "A payment to your account from shares you hold"
    DEBIT = "Debit", "Funds taken out from your account, uncategorised by the bank"
    FEE_CHARGE = "Fee Charge", "Fees or charges in relation to a transaction"
    INTEREST = "Interest", "Credit or debit associated with interest earned or incurred"
    OTHER = "Other", "Miscellaneous credit or debit"
    PURCHASE = "Purchase", "A payment made with your debit or credit card"
    STANDING_ORDER = (
        "Standing Order",
        "A payment instructed by the account-holder to a third party at regular"
        " intervals",
    )
    TRANSFER = "Transfer", "Transfer of money between accounts"
    UNKNOWN = "Unknown", "No classification of transaction category known"

    def __new__(cls, value: str, description: str) -> "TransactionCategory":
        """Create a member whose value is `value`, carrying `description`.

        Args:
            value (str): the display name used as the member's value
            description (str): a longer explanation of the category
        """
        member = object.__new__(cls)
        # `_value_` must be set here so the enum machinery registers the member
        # under the string rather than the whole (value, description) tuple.
        member._value_ = value
        member.description = description
        return member

TrueLayerClient

Bases: OAuthClient[dict[Literal['results'], list[TrueLayerEntityJson]]]

Custom client for interacting with TrueLayer's APIs.

Source code in wg_utilities/clients/truelayer.py
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
class TrueLayerClient(OAuthClient[dict[Literal["results"], list[TrueLayerEntityJson]]]):
    """Custom client for interacting with TrueLayer's APIs."""

    AUTH_LINK_BASE = "https://auth.truelayer.com/"
    ACCESS_TOKEN_ENDPOINT = "https://auth.truelayer.com/connect/token"  # noqa: S105
    BASE_URL = "https://api.truelayer.com"

    DEFAULT_SCOPES: ClassVar[list[str]] = [
        "info",
        "accounts",
        "balance",
        "cards",
        "transactions",
        "direct_debits",
        "standing_orders",
        "offline_access",
    ]

    def __init__(  # noqa: PLR0913
        self,
        *,
        client_id: str,
        client_secret: str,
        log_requests: bool = False,
        creds_cache_path: Path | None = None,
        creds_cache_dir: Path | None = None,
        scopes: list[str] | None = None,
        oauth_login_redirect_host: str = "localhost",
        oauth_redirect_uri_override: str | None = None,
        headless_auth_link_callback: Callable[[str], None] | None = None,
        use_existing_credentials_only: bool = False,
        validate_request_success: bool = True,
        bank: Bank,
    ):
        """Initialise the client for a single bank.

        Every argument other than `bank` is forwarded to the parent class'
        initialiser, along with the TrueLayer URLs defined on this class. If no
        scopes are given then `DEFAULT_SCOPES` is used.

        Args:
            bank (Bank): the bank which this client instance is for; it is also
                used to separate the cached credentials (see
                `_creds_rel_file_path`)
        """
        super().__init__(
            base_url=self.BASE_URL,
            access_token_endpoint=self.ACCESS_TOKEN_ENDPOINT,
            auth_link_base=self.AUTH_LINK_BASE,
            client_id=client_id,
            client_secret=client_secret,
            log_requests=log_requests,
            creds_cache_path=creds_cache_path,
            creds_cache_dir=creds_cache_dir,
            scopes=scopes or self.DEFAULT_SCOPES,
            oauth_login_redirect_host=oauth_login_redirect_host,
            oauth_redirect_uri_override=oauth_redirect_uri_override,
            headless_auth_link_callback=headless_auth_link_callback,
            validate_request_success=validate_request_success,
            use_existing_credentials_only=use_existing_credentials_only,
        )

        self.bank = bank

    @staticmethod
    def _get_error_code(exc: HTTPError) -> object:
        """Get the `error` value from the JSON body of a failed response.

        Args:
            exc (HTTPError): the exception raised by the failed request

        Returns:
            object: the value of the `error` key, or None if there is no response,
                the body is not valid JSON, or the body is not a JSON object
        """
        if exc.response is None:
            return None

        try:
            body = exc.response.json()
        except ValueError:
            # A non-JSON error body must not mask the original HTTPError
            return None

        return body.get("error") if isinstance(body, dict) else None

    def _get_entity_by_id(
        self,
        entity_id: str,
        entity_class: type[AccountOrCard],
    ) -> AccountOrCard | None:
        """Get entity info based on a given ID.

        Args:
            entity_id (str): the unique ID for the account/card
            entity_class (type): the class to instantiate with the returned info

        Returns:
            Union([Account, Card, None]): an instance of `entity_class` with the
                associated info, or None if the API reports that the ID was not
                found

        Raises:
            HTTPError: if a HTTPError is raised by the request, and it's not because
                the ID wasn't found
            ValueError: if the number of results returned from the TrueLayer API
                is not exactly 1
        """
        try:
            results = self.get_json_response(
                f"/data/v1/{entity_class.__name__.lower()}s/{entity_id}",
            ).get("results", [])
        except HTTPError as exc:
            if self._get_error_code(exc) == "account_not_found":
                return None
            raise

        if len(results) != 1:
            raise ValueError(
                f"Unexpected number of results when getting {entity_class.__name__}:"
                f" {len(results)}",
            )

        return entity_class.from_json_response(results[0], truelayer_client=self)

    def _list_entities(self, entity_class: type[AccountOrCard]) -> list[AccountOrCard]:
        """List all accounts/cards under the given bank account.

        Args:
            entity_class (type): the class to instantiate with the returned info

        Returns:
            list[Union([Account, Card])]: a list of Account/Card instances with
                associated info; empty if the bank does not support the endpoint

        Raises:
            HTTPError: if a HTTPError is raised by the request, and it's not
                because the endpoint is unsupported by the bank
        """
        try:
            res = self.get_json_response(f"/data/v1/{entity_class.__name__.lower()}s")
        except HTTPError as exc:
            if self._get_error_code(exc) != "endpoint_not_supported":
                raise

            # The entity name is passed as a lazy logging argument; previously the
            # literal text "{entity_class.__name__}" was logged
            LOGGER.warning(
                "%ss endpoint not supported by %s",
                entity_class.__name__,
                self.bank.value,
            )
            res = {}

        return [
            entity_class.from_json_response(result, truelayer_client=self)
            for result in res.get("results", [])
        ]

    def get_account_by_id(
        self,
        account_id: str,
    ) -> Account | None:
        """Get an Account instance based on the ID.

        Args:
            account_id (str): the ID of the account

        Returns:
            Account: an Account instance, with all relevant info, or None if the
                account was not found
        """
        return self._get_entity_by_id(account_id, Account)

    def get_card_by_id(
        self,
        card_id: str,
    ) -> Card | None:
        """Get a Card instance based on the ID.

        Args:
            card_id (str): the ID of the card

        Returns:
            Card: a Card instance, with all relevant info, or None if the card
                was not found
        """
        return self._get_entity_by_id(card_id, Card)

    def list_accounts(self) -> list[Account]:
        """List all accounts under the given bank account.

        Returns:
            list[Account]: Account instances, containing all related info
        """
        return self._list_entities(Account)

    def list_cards(self) -> list[Card]:
        """List all cards under the given bank account.

        Returns:
            list[Card]: Card instances, containing all related info
        """
        return self._list_entities(Card)

    @property
    def _creds_rel_file_path(self) -> Path | None:
        """Get the credentials cache filepath relative to the cache directory.

        TrueLayer shares the same Client ID for all banks, so this overrides the default
        to separate credentials by bank.

        Returns:
            Path: `<class name>/<client ID>/<bank name, lowercased>.json`, or None
                if no client ID can be determined yet
        """

        try:
            client_id = self._client_id or self._credentials.client_id
        except AttributeError:
            return None

        return Path(type(self).__name__, client_id, f"{self.bank.name.lower()}.json")

get_account_by_id(account_id)

Get an Account instance based on the ID.

Parameters:

Name Type Description Default
account_id str

the ID of the account

required

Returns:

Name Type Description
Account Account | None

an Account instance, with all relevant info

Source code in wg_utilities/clients/truelayer.py
687
688
689
690
691
692
693
694
695
696
697
698
699
def get_account_by_id(
    self,
    account_id: str,
) -> Account | None:
    """Look up a single account by its unique ID.

    Args:
        account_id (str): the ID of the account

    Returns:
        Account: the result of `_get_entity_by_id` for this ID and the
            `Account` entity class
    """
    account = self._get_entity_by_id(account_id, Account)
    return account

get_card_by_id(card_id)

Get a Card instance based on the ID.

Parameters:

Name Type Description Default
card_id str

the ID of the card

required

Returns:

Name Type Description
Card Card | None

a Card instance, with all relevant info

Source code in wg_utilities/clients/truelayer.py
701
702
703
704
705
706
707
708
709
710
711
712
713
def get_card_by_id(
    self,
    card_id: str,
) -> Card | None:
    """Look up a single card by its unique ID.

    Args:
        card_id (str): the ID of the card

    Returns:
        Card: the result of `_get_entity_by_id` for this ID and the `Card`
            entity class
    """
    card = self._get_entity_by_id(card_id, Card)
    return card

list_accounts()

List all accounts under the given bank account.

Returns:

Type Description
list[Account]

list[Account]: Account instances, containing all related info

Source code in wg_utilities/clients/truelayer.py
715
716
717
718
719
720
721
def list_accounts(self) -> list[Account]:
    """Fetch every account available under the given bank account.

    Returns:
        list[Account]: the result of `_list_entities` for the `Account` entity
            class
    """
    accounts = self._list_entities(Account)
    return accounts

list_cards()

List all cards under the given bank account.

Returns:

Type Description
list[Card]

list[Card]: Card instances, containing all related info

Source code in wg_utilities/clients/truelayer.py
723
724
725
726
727
728
729
def list_cards(self) -> list[Card]:
    """Fetch every card available under the given bank account.

    Returns:
        list[Card]: the result of `_list_entities` for the `Card` entity class
    """
    cards = self._list_entities(Card)
    return cards

TrueLayerEntity

Bases: BaseModelWithConfig

Parent class for all TrueLayer entities (accounts, cards, etc.).

Source code in wg_utilities/clients/truelayer.py
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
class TrueLayerEntity(BaseModelWithConfig):
    """Parent class for all TrueLayer entities (accounts, cards, etc.)."""

    # Names of the balance properties which are relevant for this entity type.
    # Anything not listed is skipped by `update_balance_values` and reported as
    # None by `_get_balance_property`.
    BALANCE_FIELDS: ClassVar[Iterable[str]] = ()

    id: str = Field(alias="account_id")
    currency: str
    display_name: str
    provider: _TrueLayerEntityProvider
    update_timestamp: str

    _available_balance: float
    _current_balance: float

    truelayer_client: TrueLayerClient = Field(exclude=True)
    balance_update_threshold: timedelta = Field(timedelta(minutes=15), exclude=True)
    # Timezone-aware, as it is compared against `datetime.now(UTC)`; comparing a
    # naive datetime with an aware one raises a TypeError
    last_balance_update: datetime = Field(
        datetime(1970, 1, 1, tzinfo=UTC),
        exclude=True,
    )
    _balance_variables: BalanceVariables

    @classmethod
    def from_json_response(
        cls,
        value: TrueLayerEntityJson,
        *,
        truelayer_client: TrueLayerClient,
    ) -> Self:
        """Create an entity from a JSON response.

        Args:
            value (TrueLayerEntityJson): a single result from the API response
            truelayer_client (TrueLayerClient): the client to attach to the
                entity, used for any further requests it makes

        Returns:
            Self: the validated instance
        """

        value_data: dict[str, Any] = {
            "truelayer_client": truelayer_client,
            **value,
        }

        return cls.model_validate(value_data)

    def get_transactions(
        self,
        from_datetime: datetime | None = None,
        to_datetime: datetime | None = None,
    ) -> list[Transaction]:
        """Get transactions for this entity.

        Polls the TL API to get all transactions under the given entity. If
        only one datetime parameter is provided, then the other is given a default
        value which maximises the range of results returned (90 days ago for the
        lower bound, now for the upper bound). If neither is provided, no date
        range is sent with the request.

        Args:
            from_datetime (datetime): lower range of transaction date range query
            to_datetime (datetime): upper range of transaction date range query

        Returns:
            list[Transaction]: one instance per tx, including all metadata etc.
        """

        if from_datetime or to_datetime:
            from_datetime = from_datetime or datetime.now(UTC) - timedelta(days=90)
            to_datetime = to_datetime or datetime.now(UTC)

            params: (
                dict[
                    StrBytIntFlt,
                    StrBytIntFlt | Iterable[StrBytIntFlt] | None,
                ]
                | None
            ) = {
                "from": from_datetime.isoformat(),
                "to": to_datetime.isoformat(),
            }
        else:
            params = None

        return [
            Transaction.model_validate(result)
            for result in self.truelayer_client.get_json_response(
                f"/data/v1/{self.__class__.__name__.lower()}s/{self.id}/transactions",
                params=params,
            ).get("results", [])
        ]

    def update_balance_values(self) -> None:
        """Update the balance-related instance attributes.

        Uses the latest values from the API. This is called automatically when
        the balance-related attributes are accessed (if the attribute is None or
        was updated more than `self.balance_update_threshold` ago), but can also
        be called manually.

        Raises:
            ValueError: if the number of results returned from the API is not
                exactly 1
        """

        results = self.truelayer_client.get_json_response(
            f"/data/v1/{self.__class__.__name__.lower()}s/{self.id}/balance",
        ).get("results", [])

        if len(results) != 1:
            raise ValueError(
                "Unexpected number of results when getting balance info:"
                f" {len(results)}",
            )

        balance_result = results[0]

        for k, v in balance_result.items():
            if k in (
                "available",
                "current",
            ):
                # These two keys are stored with a `_balance` suffix
                attr_name = f"_{k}_balance"
            elif k.endswith("_date"):
                attr_name = f"_{k}"
                if isinstance(v, str):
                    # Only the date part of the timestamp is kept
                    v = datetime.strptime(  # noqa: PLW2901
                        v,
                        "%Y-%m-%dT%H:%M:%SZ",
                    ).date()
            else:
                attr_name = f"_{k}"

            if attr_name.lstrip("_") not in self.BALANCE_FIELDS:
                LOGGER.info("Skipping %s as it's not relevant for this entity type", k)
                continue

            LOGGER.info("Updating %s with value %s", attr_name, v)

            setattr(self, attr_name, v)

        self.last_balance_update = datetime.now(UTC)

    @overload
    def _get_balance_property(
        self,
        prop_name: Literal["current_balance"],
    ) -> float: ...

    @overload
    def _get_balance_property(
        self,
        prop_name: Literal[
            "available_balance",
            "overdraft",
            "credit_limit",
            "last_statement_balance",
            "payment_due",
        ],
    ) -> float | None: ...

    @overload
    def _get_balance_property(
        self,
        prop_name: Literal[
            "last_statement_date",
            "payment_due_date",
        ],
    ) -> date | None: ...

    def _get_balance_property(
        self,
        prop_name: Literal[
            "available_balance",
            "current_balance",
            "overdraft",
            "credit_limit",
            "last_statement_balance",
            "last_statement_date",
            "payment_due",
            "payment_due_date",
        ],
    ) -> float | date | None:
        """Get a value for a balance-specific property.

        Updates the values if necessary (i.e. if they don't already exist, are
        None, or were last updated at least `self.balance_update_threshold` ago).
        This also has a check to see if property is relevant for the given entity
        type and if not it just returns None.

        Args:
            prop_name (str): the name of the property

        Returns:
            float | date | None: the value of the balance property, or None if it
                is not relevant for this entity type or is still unset after
                updating
        """

        if prop_name not in self.BALANCE_FIELDS:
            return None

        if (
            not hasattr(self, f"_{prop_name}")
            or getattr(self, f"_{prop_name}") is None
            or self.last_balance_update
            <= (datetime.now(UTC) - self.balance_update_threshold)
        ):
            self.update_balance_values()

        return getattr(self, f"_{prop_name}", None)

    @property
    def available_balance(self) -> float | None:
        """Available balance for the entity.

        Returns:
            float: the amount of money available to the bank account holder
        """
        return self._get_balance_property("available_balance")

    @property
    def balance(self) -> float:
        """Get the available balance, or current if available is not available."""
        available_balance = self.available_balance

        # Compare against None explicitly: an available balance of exactly zero
        # is a real value and must not fall through to the current balance
        if available_balance is not None:
            return available_balance

        return self.current_balance

    @property
    def current_balance(self) -> float:
        """Current balance of the account.

        Returns:
            float: the total amount of money in the account, including pending
                transactions
        """
        return self._get_balance_property("current_balance")

    def __str__(self) -> str:
        """Return a string representation of the entity."""
        return f"{self.display_name} | {self.provider.display_name}"

available_balance: float | None property

Available balance for the entity.

Returns:

Name Type Description
float float | None

the amount of money available to the bank account holder

balance: float property

Get the available balance, or current if available is not available.

current_balance: float property

Current balance of the account.

Returns:

Name Type Description
float float

the total amount of money in the account, including pending transactions

__str__()

Return a string representation of the entity.

Source code in wg_utilities/clients/truelayer.py
409
410
411
def __str__(self) -> str:
    """Describe the entity as `<entity display name> | <provider display name>`."""
    entity_name = self.display_name
    provider_name = self.provider.display_name
    return f"{entity_name} | {provider_name}"

from_json_response(value, *, truelayer_client) classmethod

Create an account from a JSON response.

Source code in wg_utilities/clients/truelayer.py
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
@classmethod
def from_json_response(
    cls,
    value: TrueLayerEntityJson,
    *,
    truelayer_client: TrueLayerClient,
) -> Self:
    """Build an entity from a single item of a TrueLayer JSON response.

    The client is placed into the payload first and the response fields are
    merged on top of it (so a response key of the same name takes precedence);
    the combined mapping is then passed to `model_validate`.

    Args:
        value (TrueLayerEntityJson): the JSON data describing the entity
        truelayer_client (TrueLayerClient): the client to attach to the entity

    Returns:
        Self: the result of validating the combined payload
    """

    payload: dict[str, Any] = {"truelayer_client": truelayer_client}
    payload.update(value)

    return cls.model_validate(payload)

get_transactions(from_datetime=None, to_datetime=None)

Get transactions for this entity.

Polls the TL API to get all transactions under the given entity. If only one datetime parameter is provided, then the other is given a default value which maximises the range of results returned.

Parameters:

Name Type Description Default
from_datetime datetime

lower range of transaction date range query

None
to_datetime datetime

upper range of transaction date range query

None

Returns:

Type Description
list[Transaction]

list[Transaction]: one instance per tx, including all metadata etc.

Source code in wg_utilities/clients/truelayer.py
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
def get_transactions(
    self,
    from_datetime: datetime | None = None,
    to_datetime: datetime | None = None,
) -> list[Transaction]:
    """Get transactions for this entity.

    Polls the TL API to get all transactions under the given entity. If
    only one datetime parameter is provided, then the other is given a default
    value which maximises the range of results returned

    Args:
        from_datetime (datetime): lower range of transaction date range query
        to_datetime (datetime): upper range of transaction date range query

    Returns:
        list[Transaction]: one instance per tx, including all metadata etc.
    """

    if from_datetime or to_datetime:
        from_datetime = from_datetime or datetime.now(UTC) - timedelta(days=90)
        to_datetime = to_datetime or datetime.now(UTC)

        params: (
            dict[
                StrBytIntFlt,
                StrBytIntFlt | Iterable[StrBytIntFlt] | None,
            ]
            | None
        ) = {
            "from": from_datetime.isoformat(),
            "to": to_datetime.isoformat(),
        }
    else:
        params = None

    return [
        Transaction.model_validate(result)
        for result in self.truelayer_client.get_json_response(
            f"/data/v1/{self.__class__.__name__.lower()}s/{self.id}/transactions",
            params=params,
        ).get("results", [])
    ]

update_balance_values()

Update the balance-related instance attributes.

Uses the latest values from the API. This is called automatically when the balance-related attributes are accessed (if the attribute is None or was updated more than self.balance_update_threshold ago), but can also be called manually.

Source code in wg_utilities/clients/truelayer.py
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
def update_balance_values(self) -> None:
    """Update the balance-related instance attributes.

    Uses the latest values from the API. This is called automatically when
    the balance-related attributes are accessed (if the attribute is None or
    was updated more than `self.balance_update_threshold`minutes ago), but
    can also be called manually.
    """

    results = self.truelayer_client.get_json_response(
        f"/data/v1/{self.__class__.__name__.lower()}s/{self.id}/balance",
    ).get("results", [])

    if len(results) != 1:
        raise ValueError(
            "Unexpected number of results when getting balance info:"
            f" {len(results)}",
        )

    balance_result = results[0]

    for k, v in balance_result.items():
        if k in (
            "available",
            "current",
        ):
            attr_name = f"_{k}_balance"
        elif k.endswith("_date"):
            attr_name = f"_{k}"
            if isinstance(v, str):
                v = datetime.strptime(  # noqa: PLW2901
                    v,
                    "%Y-%m-%dT%H:%M:%SZ",
                ).date()
        else:
            attr_name = f"_{k}"

        if attr_name.lstrip("_") not in self.BALANCE_FIELDS:
            LOGGER.info("Skipping %s as it's not relevant for this entity type", k)
            continue

        LOGGER.info("Updating %s with value %s", attr_name, v)

        setattr(self, attr_name, v)

    self.last_balance_update = datetime.now(UTC)