# Source code for redditpythonapi.reddit

"""
Wrapper for the Reddit API.
Using a wrapper simplifies access to the API, mostly by handling OAuth authorization.
"""

from enum import StrEnum
from logging import getLogger
from time import time_ns
from typing import Any

from httpx import AsyncClient, BasicAuth, Response

# A single Reddit article: the "data" mapping of one child entry in a Reddit
# listing response (see Reddit._get_articles). Kept as a loose dict because the
# article schema is defined by Reddit, not by this wrapper.
Article = dict[str, Any]


[docs] class ArticlesSortType(StrEnum): """Enum with all viable sorting types""" HOT = "hot" NEW = "new" RISING = "rising" TOP = "top" CONTROVERSIAL = "controversial"
[docs] class ArticlesSortTime(StrEnum): """Enum with all viable sort times""" HOUR = "hour" DAY = "day" WEEK = "week" MONTH = "month" YEAR = "year" ALL = "all"
class Reddit:
    """Class wrapping Reddit API

    Class wrapping calls to Reddit API. Handles all necessary URLs, parameters, headers, etc.
    Also handles requesting new OAuth 2.0 access tokens and authorization in general.

    Args:
        client_id (str): Reddit app client ID to use for authorization
        client_secret (str): Reddit app client secret to use for authorization
        user_agent (str): user agent used in all Reddit API requests
    """

    _ACCESS_TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
    _SUBREDDIT_ARTICLES_URL = "https://oauth.reddit.com/r/{subreddit}/{sort}"
    _USER_ARTICLES_URL = "https://oauth.reddit.com/user/{user}/submitted"
    # Safety margin (60 s, in nanoseconds) subtracted from the token lifetime, so a new
    # token is requested slightly before Reddit reports the current one as expired.
    _AUTH_EXPIRY_OVERHEAD_NS = 60_000_000_000
    # Response codes after which the token is refreshed and the request retried once.
    _REAUTHORIZE_STATUS_CODES = frozenset({401, 403})

    def __init__(
        self,
        client_id: str,
        client_secret: str,
        user_agent: str = "Reddit Python API (by Electronic-Mango on GitHub)",
    ) -> None:
        self._client_auth = BasicAuth(username=client_id, password=client_secret)
        # Sent with every article request; "Authorization" is added by _authorize().
        self._auth_headers = {"User-agent": user_agent}
        # Absolute deadline on the time_ns() clock after which the token is refreshed;
        # 0 forces authorization before the first request.
        self._access_token_expires_in = 0
        self._logger = getLogger(__name__)

    async def subreddit_articles(
        self,
        subreddit: str,
        sort: ArticlesSortType | None = None,
        time: ArticlesSortTime | None = None,
        limit: int | None = None,
    ) -> list[Article]:
        """Get a list of Reddit articles from the given subreddit

        Args:
            subreddit (str): subreddit to load articles from
            sort (ArticlesSortType | None): sort type to use when loading articles,
                "hot" by default
            time (ArticlesSortTime | None): sort time to use when loading articles,
                by default not passed onto Reddit API
            limit (int | None): up to how many articles should be loaded,
                by default not passed onto Reddit API

        Returns:
            list[Article]: list of loaded articles from the given subreddit
        """
        self._logger.info(
            "Loading subreddit articles [%s] [%s] [%s] [%s]", subreddit, sort, time, limit
        )
        sort = sort or ArticlesSortType.HOT
        # For subreddits the sort order is part of the URL path, not a query parameter.
        url = self._SUBREDDIT_ARTICLES_URL.format(subreddit=subreddit, sort=sort.value)
        params = self._prepare_params(limit=limit, time=time)
        return await self._get_articles(url, params)

    async def user_articles(
        self,
        user: str,
        sort: ArticlesSortType | None = None,
        time: ArticlesSortTime | None = None,
        limit: int | None = None,
    ) -> list[Article]:
        """Get a list of Reddit articles from the given Reddit user

        Args:
            user (str): Reddit user to load articles from
            sort (ArticlesSortType | None): sort type to use when loading articles,
                by default not passed onto Reddit API
            time (ArticlesSortTime | None): sort time to use when loading articles,
                by default not passed onto Reddit API
            limit (int | None): up to how many articles should be loaded,
                by default not passed onto Reddit API

        Returns:
            list[Article]: list of loaded articles from the Reddit user
        """
        self._logger.info(
            "Loading user articles [%s] [%s] [%s] [%s]", user, sort, time, limit
        )
        url = self._USER_ARTICLES_URL.format(user=user)
        # For users the sort order is sent as the "sort" query parameter.
        params = self._prepare_params(limit=limit, sort=sort, time=time)
        return await self._get_articles(url, params)

    async def _authorize(self) -> None:
        """Obtain a fresh access token and store it for subsequent requests.

        Stores the token as an ``Authorization: Bearer ...`` header in
        ``self._auth_headers`` and records the deadline after which it must be
        refreshed (Reddit's ``expires_in`` minus ``_AUTH_EXPIRY_OVERHEAD_NS``).

        Raises:
            httpx.HTTPStatusError: if the token endpoint returns an error status
            KeyError: if the token response lacks ``access_token`` or ``expires_in``
        """
        self._logger.info("Authorizing")
        response = await self._request_access_token()
        response.raise_for_status()
        response_content = response.json()
        access_token = response_content["access_token"]
        self._auth_headers["Authorization"] = f"Bearer {access_token}"
        # expires_in is a lifetime in seconds; convert to ns to match time_ns().
        expires_in_ns = response_content["expires_in"] * 1_000_000_000
        self._access_token_expires_in = time_ns() + expires_in_ns - self._AUTH_EXPIRY_OVERHEAD_NS

    async def _request_access_token(self) -> Response:
        """Request a new application-only (client_credentials) access token.

        The request is authenticated with HTTP Basic auth built from the client ID
        and secret. Only the User-agent header is forwarded; any previously stored
        Bearer token is not sent to the token endpoint.

        Returns:
            Response: raw HTTP response from the access token endpoint
        """
        async with AsyncClient() as client:
            return await client.post(
                url=self._ACCESS_TOKEN_URL,
                params={"grant_type": "client_credentials"},
                auth=self._client_auth,
                headers={"User-agent": self._auth_headers["User-agent"]},
            )

    def _prepare_params(
        self,
        sort: ArticlesSortType | None = None,
        time: ArticlesSortTime | None = None,
        limit: int | None = None,
    ) -> dict[str, Any]:
        """Build listing query parameters, omitting every argument that is None.

        Args:
            sort (ArticlesSortType | None): sent as ``sort`` when given
            time (ArticlesSortTime | None): sent as ``t`` when given
            limit (int | None): sent as ``limit`` when given

        Returns:
            dict[str, Any]: query parameters to pass along with the request
        """
        return {
            **({"sort": sort.value} if sort is not None else {}),
            **({"t": time.value} if time is not None else {}),
            **({"limit": limit} if limit is not None else {}),
        }

    async def _get_articles(self, url: str, params: dict[str, Any]) -> list[Article]:
        """Fetch a Reddit listing and return the articles it contains.

        Authorizes first if the stored token has passed its refresh deadline. If
        Reddit still answers with 401/403, the token is refreshed and the request
        is retried exactly once.

        Args:
            url (str): listing endpoint URL
            params (dict[str, Any]): query parameters for the listing request

        Returns:
            list[Article]: the ``data`` payload of each child in the listing

        Raises:
            httpx.HTTPStatusError: if the (possibly retried) request still fails
        """
        if self._access_token_expires_in <= time_ns():
            self._logger.info("Access token expired, requesting new one")
            await self._authorize()
        response = await self._request_articles(url, params)
        if response.status_code in self._REAUTHORIZE_STATUS_CODES:
            # The token may have been revoked or expired early; refresh and retry once.
            self._logger.info(
                "Response returned code [%s], re-authorizing", response.status_code
            )
            await self._authorize()
            response = await self._request_articles(url, params)
        response.raise_for_status()
        listing = response.json()
        return [child["data"] for child in listing["data"]["children"]]

    async def _request_articles(self, url: str, params: dict[str, Any]) -> Response:
        """Send one authenticated GET request (no retry, no status check).

        Args:
            url (str): endpoint URL
            params (dict[str, Any]): query parameters

        Returns:
            Response: raw HTTP response
        """
        async with AsyncClient() as client:
            return await client.get(url=url, params=params, headers=self._auth_headers)