"""GitHub API client implementation."""

import logging
import time
from datetime import datetime
from typing import Any, ClassVar

import httpx
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from arbiter.integrations.base import (
    Comment,
    CommitStatus,
    Platform,
    PlatformClient,
    PullRequestInfo,
)
from arbiter.integrations.exceptions import (
    AuthenticationError,
    IntegrationError,
    NotFoundError,
    PlatformError,
    RateLimitError,
)

logger = logging.getLogger(__name__)


class GitHubClient(PlatformClient):
    """GitHub API client for PR operations."""

    # GitHub API version header
    API_VERSION = "2022-11-28"

    # Status mapping from our enum to GitHub's expected values
    STATUS_MAP: ClassVar[dict[CommitStatus, str]] = {
        CommitStatus.PENDING: "pending",
        CommitStatus.SUCCESS: "success",
        CommitStatus.FAILURE: "failure",
        CommitStatus.ERROR: "error",
    }

    # Log a warning once fewer than this many requests remain in the window
    RATE_LIMIT_WARNING_THRESHOLD: ClassVar[int] = 10

    # Page size requested when listing comments (GitHub's documented maximum)
    COMMENTS_PER_PAGE: ClassVar[int] = 100

    def __init__(
        self,
        token: str,
        base_url: str = "https://api.github.com",
        timeout: int = 30,
        max_retries: int = 3,
    ) -> None:
        """Initialize GitHub client.

        Args:
            token: GitHub personal access token or app token.
            base_url: GitHub API base URL (for enterprise instances).
            timeout: Request timeout in seconds.
            max_retries: Maximum number of attempts (including the first one)
                made for a request that keeps failing with a transient error.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.max_retries = max_retries

        self._client = httpx.AsyncClient(
            base_url=self.base_url,
            headers={
                "Authorization": f"Bearer {token}",
                "Accept": "application/vnd.github+json",
                "X-GitHub-Api-Version": self.API_VERSION,
            },
            timeout=timeout,
        )

    @property
    def platform(self) -> Platform:
        """Return the platform this client talks to (always GitHub)."""
        return Platform.GITHUB

    @staticmethod
    def _parse_int_header(response: httpx.Response, name: str) -> int | None:
        """Read an integer response header, tolerating absence and garbage.

        Args:
            response: HTTP response to read the header from.
            name: Header name.

        Returns:
            The header value as an int, or None if it is missing, empty or
            not a valid integer.
        """
        raw = response.headers.get(name)
        if raw is None:
            return None
        try:
            return int(raw)
        except ValueError:
            # A malformed header must never mask the real outcome of a request.
            return None

    def _is_rate_limited(self, response: httpx.Response) -> bool:
        """Tell whether an error response is a rate-limit rejection.

        GitHub signals rate limiting with 429 but also with 403, so a 403 is
        only treated as a rate limit when the headers say so.

        Args:
            response: Unsuccessful HTTP response.

        Returns:
            True if the response is a rate-limit rejection.
        """
        status = response.status_code
        if status == 429:
            return True
        if status != 403:
            return False
        if "Retry-After" in response.headers:
            return True
        return self._parse_int_header(response, "X-RateLimit-Remaining") == 0

    def _retry_after_seconds(self, response: httpx.Response) -> int | None:
        """Work out how long to wait before retrying a rate-limited request.

        Args:
            response: Rate-limited HTTP response.

        Returns:
            Seconds to wait, or None when the response gives no usable hint.
        """
        retry_after = self._parse_int_header(response, "Retry-After")
        if retry_after is not None:
            return retry_after

        # Fall back to the window reset time, given as epoch seconds.
        reset_at = self._parse_int_header(response, "X-RateLimit-Reset")
        if reset_at is not None:
            return max(0, reset_at - int(time.time()))
        return None

    def _handle_response_errors(self, response: httpx.Response, context: str) -> None:
        """Check response for errors and raise appropriate exceptions.

        Args:
            response: HTTP response to check.
            context: Context string for error messages.

        Raises:
            RateLimitError: On 429, or on 403 carrying rate-limit headers.
            AuthenticationError: On 401, or on 403 that is not a rate limit.
            NotFoundError: On 404.
            PlatformError: On any 5xx status.
            IntegrationError: On any other unsuccessful status.
        """
        if response.is_success:
            return

        status = response.status_code
        body = response.text

        # Must run before the 401/403 branch: a rate-limited 403 is not an
        # authentication failure.
        if self._is_rate_limited(response):
            raise RateLimitError(
                f"GitHub rate limit exceeded for {context}",
                retry_after=self._retry_after_seconds(response),
                status_code=status,
                response_body=body,
            )

        if status in (401, 403):
            raise AuthenticationError(
                f"GitHub authentication failed for {context}: {body}",
                status_code=status,
                response_body=body,
            )

        if status == 404:
            raise NotFoundError(
                f"GitHub resource not found for {context}: {body}",
                status_code=status,
                response_body=body,
            )

        if status >= 500:
            raise PlatformError(
                f"GitHub server error for {context}: {body}",
                status_code=status,
                response_body=body,
            )

        raise IntegrationError(
            f"GitHub API error for {context}: {body}",
            status_code=status,
            response_body=body,
        )

    def _check_rate_limit(self, response: httpx.Response) -> None:
        """Log warning if rate limit is getting low.

        Args:
            response: HTTP response with rate limit headers.
        """
        remaining = self._parse_int_header(response, "X-RateLimit-Remaining")
        if remaining is None or remaining >= self.RATE_LIMIT_WARNING_THRESHOLD:
            return

        limit = response.headers.get("X-RateLimit-Limit", "unknown")
        logger.warning(
            "GitHub rate limit low: %d/%s remaining",
            remaining,
            limit,
        )

    async def _request_with_retry(
        self,
        method: str,
        url: str,
        context: str,
        **kwargs: Any,
    ) -> httpx.Response:
        """Make a request with automatic retry on transient errors.

        Only server errors (PlatformError) and transport failures are
        retried; every other error is raised on the first occurrence.

        Args:
            method: HTTP method.
            url: Request URL.
            context: Context string for error messages.
            **kwargs: Additional request arguments.

        Returns:
            HTTP response.

        Raises:
            PlatformError: If the server keeps failing after all attempts.
            httpx.TransportError: If the transport keeps failing after all
                attempts.
            IntegrationError: For any non-retried error status.
        """

        @retry(
            retry=retry_if_exception_type((PlatformError, httpx.TransportError)),
            stop=stop_after_attempt(self.max_retries),
            wait=wait_exponential(multiplier=1, min=1, max=10),
            # Surface the last real exception rather than tenacity.RetryError,
            # so callers can keep catching the integration exception types.
            reraise=True,
        )
        async def _do_request() -> httpx.Response:
            response = await self._client.request(method, url, **kwargs)
            self._check_rate_limit(response)
            self._handle_response_errors(response, context)
            return response

        return await _do_request()

    async def get_pr_diff(self, repository: str, pr_number: int) -> str:
        """Fetch the unified diff for a pull request.

        Args:
            repository: Repository name in owner/repo format.
            pr_number: Pull request number.

        Returns:
            Unified diff as a string.
        """
        url = f"/repos/{repository}/pulls/{pr_number}"
        response = await self._request_with_retry(
            "GET",
            url,
            f"get diff for {repository}#{pr_number}",
            # The diff media type makes GitHub return the diff, not JSON.
            headers={"Accept": "application/vnd.github.v3.diff"},
        )
        return response.text

    async def post_comment(self, repository: str, pr_number: int, body: str) -> str:
        """Post a comment on a pull request.

        Args:
            repository: Repository name in owner/repo format.
            pr_number: Pull request number.
            body: Comment body in Markdown.

        Returns:
            URL to the posted comment (empty string if the response has none).
        """
        # PR conversation comments are posted through the issues endpoint.
        url = f"/repos/{repository}/issues/{pr_number}/comments"
        response = await self._request_with_retry(
            "POST",
            url,
            f"post comment on {repository}#{pr_number}",
            json={"body": body},
        )
        data: dict[str, Any] = response.json()
        return str(data.get("html_url", ""))

    async def update_commit_status(
        self,
        repository: str,
        sha: str,
        status: CommitStatus,
        description: str,
        context: str,
        target_url: str | None = None,
    ) -> None:
        """Update the commit status.

        Args:
            repository: Repository name in owner/repo format.
            sha: Commit SHA to update status for.
            status: Status state.
            description: Short description (truncated to 140 chars).
            context: Status context/name.
            target_url: Optional URL for more details.
        """
        url = f"/repos/{repository}/statuses/{sha}"
        payload: dict[str, Any] = {
            "state": self.STATUS_MAP[status],
            "description": description[:140],  # GitHub limit
            "context": context,
        }
        if target_url:
            payload["target_url"] = target_url

        await self._request_with_retry(
            "POST",
            url,
            f"update status for {repository}@{sha[:8]}",
            json=payload,
        )

    async def get_pr_info(self, repository: str, pr_number: int) -> PullRequestInfo:
        """Get information about a pull request.

        Args:
            repository: Repository name in owner/repo format.
            pr_number: Pull request number.

        Returns:
            PullRequestInfo with PR metadata.
        """
        url = f"/repos/{repository}/pulls/{pr_number}"
        response = await self._request_with_retry(
            "GET",
            url,
            f"get info for {repository}#{pr_number}",
        )
        data: dict[str, Any] = response.json()

        # "user" may be present but null, so guard before reading the login.
        user = data.get("user") or {}

        return PullRequestInfo(
            platform=Platform.GITHUB,
            repository=repository,
            pr_number=pr_number,
            head_sha=data["head"]["sha"],
            base_sha=data["base"]["sha"],
            head_ref=data["head"]["ref"],
            base_ref=data["base"]["ref"],
            title=data["title"],
            author=user.get("login"),
            url=data["html_url"],
            is_draft=data.get("draft", False),
        )

    @staticmethod
    def _comment_from_payload(item: dict[str, Any]) -> Comment:
        """Convert one comment object from the API into a Comment."""
        # "user" and "body" may be present but null.
        user = item.get("user") or {}
        # Normalize a trailing "Z" so fromisoformat accepts it on every
        # supported Python version.
        created_at = datetime.fromisoformat(item["created_at"].replace("Z", "+00:00"))
        return Comment(
            id=str(item["id"]),
            body=item.get("body") or "",
            author=user.get("login", ""),
            url=item.get("html_url", ""),
            created_at=created_at,
        )

    async def get_comments(self, repository: str, pr_number: int) -> list[Comment]:
        """Get comments on a pull request.

        Follows pagination, so all comments are returned rather than only
        the first page.

        Args:
            repository: Repository name in owner/repo format.
            pr_number: Pull request number.

        Returns:
            List of comments on the PR.
        """
        context = f"get comments for {repository}#{pr_number}"
        next_url: str | None = f"/repos/{repository}/issues/{pr_number}/comments"
        params: dict[str, Any] | None = {"per_page": self.COMMENTS_PER_PAGE}

        comments: list[Comment] = []
        while next_url:
            response = await self._request_with_retry(
                "GET",
                next_url,
                context,
                params=params,
            )
            page: list[dict[str, Any]] = response.json()
            comments.extend(self._comment_from_payload(item) for item in page)

            # The "next" link is a complete URL that already carries the
            # query string, so no params are sent on follow-up requests.
            next_url = response.links.get("next", {}).get("url")
            params = None

        return comments

    async def update_comment(
        self, repository: str, pr_number: int, comment_id: str, body: str
    ) -> str:
        """Update an existing comment.

        Args:
            repository: Repository name in owner/repo format.
            pr_number: Pull request number (unused for GitHub, but required by interface).
            comment_id: ID of the comment to update.
            body: New comment body in Markdown.

        Returns:
            URL to the updated comment (empty string if the response has none).
        """
        _ = pr_number  # GitHub doesn't need pr_number for comment updates
        url = f"/repos/{repository}/issues/comments/{comment_id}"
        response = await self._request_with_retry(
            "PATCH",
            url,
            f"update comment {comment_id} on {repository}",
            json={"body": body},
        )
        data: dict[str, Any] = response.json()
        return str(data.get("html_url", ""))

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()