add github client
This commit is contained in:
344
src/arbiter/integrations/github.py
Normal file
344
src/arbiter/integrations/github.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""GitHub API client implementation."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import httpx
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
)
|
||||
|
||||
from arbiter.integrations.base import (
|
||||
Comment,
|
||||
CommitStatus,
|
||||
Platform,
|
||||
PlatformClient,
|
||||
PullRequestInfo,
|
||||
)
|
||||
from arbiter.integrations.exceptions import (
|
||||
AuthenticationError,
|
||||
IntegrationError,
|
||||
NotFoundError,
|
||||
PlatformError,
|
||||
RateLimitError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GitHubClient(PlatformClient):
|
||||
"""GitHub API client for PR operations."""
|
||||
|
||||
# GitHub API version header
|
||||
API_VERSION = "2022-11-28"
|
||||
|
||||
# Status mapping from our enum to GitHub's expected values
|
||||
STATUS_MAP: ClassVar[dict[CommitStatus, str]] = {
|
||||
CommitStatus.PENDING: "pending",
|
||||
CommitStatus.SUCCESS: "success",
|
||||
CommitStatus.FAILURE: "failure",
|
||||
CommitStatus.ERROR: "error",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token: str,
|
||||
base_url: str = "https://api.github.com",
|
||||
timeout: int = 30,
|
||||
max_retries: int = 3,
|
||||
) -> None:
|
||||
"""Initialize GitHub client.
|
||||
|
||||
Args:
|
||||
token: GitHub personal access token or app token.
|
||||
base_url: GitHub API base URL (for enterprise instances).
|
||||
timeout: Request timeout in seconds.
|
||||
max_retries: Maximum retry attempts for transient errors.
|
||||
"""
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
|
||||
self._client = httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Accept": "application/vnd.github+json",
|
||||
"X-GitHub-Api-Version": self.API_VERSION,
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
@property
|
||||
def platform(self) -> Platform:
|
||||
return Platform.GITHUB
|
||||
|
||||
def _handle_response_errors(self, response: httpx.Response, context: str) -> None:
|
||||
"""Check response for errors and raise appropriate exceptions.
|
||||
|
||||
Args:
|
||||
response: HTTP response to check.
|
||||
context: Context string for error messages.
|
||||
"""
|
||||
if response.is_success:
|
||||
return
|
||||
|
||||
status = response.status_code
|
||||
body = response.text
|
||||
|
||||
if status == 401 or status == 403:
|
||||
raise AuthenticationError(
|
||||
f"GitHub authentication failed for {context}: {body}",
|
||||
status_code=status,
|
||||
response_body=body,
|
||||
)
|
||||
|
||||
if status == 404:
|
||||
raise NotFoundError(
|
||||
f"GitHub resource not found for {context}: {body}",
|
||||
status_code=status,
|
||||
response_body=body,
|
||||
)
|
||||
|
||||
if status == 429:
|
||||
retry_after = response.headers.get("Retry-After")
|
||||
retry_seconds = int(retry_after) if retry_after else None
|
||||
raise RateLimitError(
|
||||
f"GitHub rate limit exceeded for {context}",
|
||||
retry_after=retry_seconds,
|
||||
status_code=status,
|
||||
response_body=body,
|
||||
)
|
||||
|
||||
if status >= 500:
|
||||
raise PlatformError(
|
||||
f"GitHub server error for {context}: {body}",
|
||||
status_code=status,
|
||||
response_body=body,
|
||||
)
|
||||
|
||||
raise IntegrationError(
|
||||
f"GitHub API error for {context}: {body}",
|
||||
status_code=status,
|
||||
response_body=body,
|
||||
)
|
||||
|
||||
def _check_rate_limit(self, response: httpx.Response) -> None:
|
||||
"""Log warning if rate limit is getting low.
|
||||
|
||||
Args:
|
||||
response: HTTP response with rate limit headers.
|
||||
"""
|
||||
remaining = response.headers.get("X-RateLimit-Remaining")
|
||||
if remaining is not None:
|
||||
remaining_int = int(remaining)
|
||||
if remaining_int < 10:
|
||||
limit = response.headers.get("X-RateLimit-Limit", "unknown")
|
||||
logger.warning(
|
||||
"GitHub rate limit low: %d/%s remaining",
|
||||
remaining_int,
|
||||
limit,
|
||||
)
|
||||
|
||||
async def _request_with_retry(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
context: str,
|
||||
**kwargs: Any,
|
||||
) -> httpx.Response:
|
||||
"""Make a request with automatic retry on transient errors.
|
||||
|
||||
Args:
|
||||
method: HTTP method.
|
||||
url: Request URL.
|
||||
context: Context string for error messages.
|
||||
**kwargs: Additional request arguments.
|
||||
|
||||
Returns:
|
||||
HTTP response.
|
||||
"""
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type((PlatformError, httpx.TransportError)),
|
||||
stop=stop_after_attempt(self.max_retries),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=10),
|
||||
)
|
||||
async def _do_request() -> httpx.Response:
|
||||
response = await self._client.request(method, url, **kwargs)
|
||||
self._check_rate_limit(response)
|
||||
self._handle_response_errors(response, context)
|
||||
return response
|
||||
|
||||
return await _do_request()
|
||||
|
||||
async def get_pr_diff(self, repository: str, pr_number: int) -> str:
|
||||
"""Fetch the unified diff for a pull request.
|
||||
|
||||
Args:
|
||||
repository: Repository name in owner/repo format.
|
||||
pr_number: Pull request number.
|
||||
|
||||
Returns:
|
||||
Unified diff as a string.
|
||||
"""
|
||||
url = f"/repos/{repository}/pulls/{pr_number}"
|
||||
response = await self._request_with_retry(
|
||||
"GET",
|
||||
url,
|
||||
f"get diff for {repository}#{pr_number}",
|
||||
headers={"Accept": "application/vnd.github.v3.diff"},
|
||||
)
|
||||
return response.text
|
||||
|
||||
async def post_comment(self, repository: str, pr_number: int, body: str) -> str:
|
||||
"""Post a comment on a pull request.
|
||||
|
||||
Args:
|
||||
repository: Repository name in owner/repo format.
|
||||
pr_number: Pull request number.
|
||||
body: Comment body in Markdown.
|
||||
|
||||
Returns:
|
||||
URL to the posted comment.
|
||||
"""
|
||||
url = f"/repos/{repository}/issues/{pr_number}/comments"
|
||||
response = await self._request_with_retry(
|
||||
"POST",
|
||||
url,
|
||||
f"post comment on {repository}#{pr_number}",
|
||||
json={"body": body},
|
||||
)
|
||||
data: dict[str, Any] = response.json()
|
||||
return str(data.get("html_url", ""))
|
||||
|
||||
async def update_commit_status(
|
||||
self,
|
||||
repository: str,
|
||||
sha: str,
|
||||
status: CommitStatus,
|
||||
description: str,
|
||||
context: str,
|
||||
target_url: str | None = None,
|
||||
) -> None:
|
||||
"""Update the commit status.
|
||||
|
||||
Args:
|
||||
repository: Repository name in owner/repo format.
|
||||
sha: Commit SHA to update status for.
|
||||
status: Status state.
|
||||
description: Short description (max 140 chars).
|
||||
context: Status context/name.
|
||||
target_url: Optional URL for more details.
|
||||
"""
|
||||
url = f"/repos/{repository}/statuses/{sha}"
|
||||
payload: dict[str, Any] = {
|
||||
"state": self.STATUS_MAP[status],
|
||||
"description": description[:140], # GitHub limit
|
||||
"context": context,
|
||||
}
|
||||
if target_url:
|
||||
payload["target_url"] = target_url
|
||||
|
||||
await self._request_with_retry(
|
||||
"POST",
|
||||
url,
|
||||
f"update status for {repository}@{sha[:8]}",
|
||||
json=payload,
|
||||
)
|
||||
|
||||
async def get_pr_info(self, repository: str, pr_number: int) -> PullRequestInfo:
|
||||
"""Get information about a pull request.
|
||||
|
||||
Args:
|
||||
repository: Repository name in owner/repo format.
|
||||
pr_number: Pull request number.
|
||||
|
||||
Returns:
|
||||
PullRequestInfo with PR metadata.
|
||||
"""
|
||||
url = f"/repos/{repository}/pulls/{pr_number}"
|
||||
response = await self._request_with_retry(
|
||||
"GET",
|
||||
url,
|
||||
f"get info for {repository}#{pr_number}",
|
||||
)
|
||||
data: dict[str, Any] = response.json()
|
||||
|
||||
return PullRequestInfo(
|
||||
platform=Platform.GITHUB,
|
||||
repository=repository,
|
||||
pr_number=pr_number,
|
||||
head_sha=data["head"]["sha"],
|
||||
base_sha=data["base"]["sha"],
|
||||
head_ref=data["head"]["ref"],
|
||||
base_ref=data["base"]["ref"],
|
||||
title=data["title"],
|
||||
author=data.get("user", {}).get("login"),
|
||||
url=data["html_url"],
|
||||
is_draft=data.get("draft", False),
|
||||
)
|
||||
|
||||
async def get_comments(self, repository: str, pr_number: int) -> list[Comment]:
|
||||
"""Get comments on a pull request.
|
||||
|
||||
Args:
|
||||
repository: Repository name in owner/repo format.
|
||||
pr_number: Pull request number.
|
||||
|
||||
Returns:
|
||||
List of comments on the PR.
|
||||
"""
|
||||
url = f"/repos/{repository}/issues/{pr_number}/comments"
|
||||
response = await self._request_with_retry(
|
||||
"GET",
|
||||
url,
|
||||
f"get comments for {repository}#{pr_number}",
|
||||
)
|
||||
data: list[dict[str, Any]] = response.json()
|
||||
|
||||
comments: list[Comment] = []
|
||||
for item in data:
|
||||
comments.append(
|
||||
Comment(
|
||||
id=str(item["id"]),
|
||||
body=item.get("body", ""),
|
||||
author=item.get("user", {}).get("login", ""),
|
||||
url=item.get("html_url", ""),
|
||||
created_at=datetime.fromisoformat(item["created_at"].replace("Z", "+00:00")),
|
||||
)
|
||||
)
|
||||
|
||||
return comments
|
||||
|
||||
async def update_comment(
|
||||
self, repository: str, pr_number: int, comment_id: str, body: str
|
||||
) -> str:
|
||||
"""Update an existing comment.
|
||||
|
||||
Args:
|
||||
repository: Repository name in owner/repo format.
|
||||
pr_number: Pull request number (unused for GitHub, but required by interface).
|
||||
comment_id: ID of the comment to update.
|
||||
body: New comment body in Markdown.
|
||||
|
||||
Returns:
|
||||
URL to the updated comment.
|
||||
"""
|
||||
_ = pr_number # GitHub doesn't need pr_number for comment updates
|
||||
url = f"/repos/{repository}/issues/comments/{comment_id}"
|
||||
response = await self._request_with_retry(
|
||||
"PATCH",
|
||||
url,
|
||||
f"update comment {comment_id} on {repository}",
|
||||
json={"body": body},
|
||||
)
|
||||
data: dict[str, Any] = response.json()
|
||||
return str(data.get("html_url", ""))
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._client.aclose()
|
||||
Reference in New Issue
Block a user