add github client
This commit is contained in:
344
src/arbiter/integrations/github.py
Normal file
344
src/arbiter/integrations/github.py
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
"""GitHub API client implementation."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, ClassVar
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from tenacity import (
|
||||||
|
retry,
|
||||||
|
retry_if_exception_type,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
)
|
||||||
|
|
||||||
|
from arbiter.integrations.base import (
|
||||||
|
Comment,
|
||||||
|
CommitStatus,
|
||||||
|
Platform,
|
||||||
|
PlatformClient,
|
||||||
|
PullRequestInfo,
|
||||||
|
)
|
||||||
|
from arbiter.integrations.exceptions import (
|
||||||
|
AuthenticationError,
|
||||||
|
IntegrationError,
|
||||||
|
NotFoundError,
|
||||||
|
PlatformError,
|
||||||
|
RateLimitError,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubClient(PlatformClient):
|
||||||
|
"""GitHub API client for PR operations."""
|
||||||
|
|
||||||
|
# GitHub API version header
|
||||||
|
API_VERSION = "2022-11-28"
|
||||||
|
|
||||||
|
# Status mapping from our enum to GitHub's expected values
|
||||||
|
STATUS_MAP: ClassVar[dict[CommitStatus, str]] = {
|
||||||
|
CommitStatus.PENDING: "pending",
|
||||||
|
CommitStatus.SUCCESS: "success",
|
||||||
|
CommitStatus.FAILURE: "failure",
|
||||||
|
CommitStatus.ERROR: "error",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
token: str,
|
||||||
|
base_url: str = "https://api.github.com",
|
||||||
|
timeout: int = 30,
|
||||||
|
max_retries: int = 3,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize GitHub client.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: GitHub personal access token or app token.
|
||||||
|
base_url: GitHub API base URL (for enterprise instances).
|
||||||
|
timeout: Request timeout in seconds.
|
||||||
|
max_retries: Maximum retry attempts for transient errors.
|
||||||
|
"""
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.timeout = timeout
|
||||||
|
self.max_retries = max_retries
|
||||||
|
|
||||||
|
self._client = httpx.AsyncClient(
|
||||||
|
base_url=self.base_url,
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {token}",
|
||||||
|
"Accept": "application/vnd.github+json",
|
||||||
|
"X-GitHub-Api-Version": self.API_VERSION,
|
||||||
|
},
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def platform(self) -> Platform:
|
||||||
|
return Platform.GITHUB
|
||||||
|
|
||||||
|
def _handle_response_errors(self, response: httpx.Response, context: str) -> None:
|
||||||
|
"""Check response for errors and raise appropriate exceptions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: HTTP response to check.
|
||||||
|
context: Context string for error messages.
|
||||||
|
"""
|
||||||
|
if response.is_success:
|
||||||
|
return
|
||||||
|
|
||||||
|
status = response.status_code
|
||||||
|
body = response.text
|
||||||
|
|
||||||
|
if status == 401 or status == 403:
|
||||||
|
raise AuthenticationError(
|
||||||
|
f"GitHub authentication failed for {context}: {body}",
|
||||||
|
status_code=status,
|
||||||
|
response_body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if status == 404:
|
||||||
|
raise NotFoundError(
|
||||||
|
f"GitHub resource not found for {context}: {body}",
|
||||||
|
status_code=status,
|
||||||
|
response_body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if status == 429:
|
||||||
|
retry_after = response.headers.get("Retry-After")
|
||||||
|
retry_seconds = int(retry_after) if retry_after else None
|
||||||
|
raise RateLimitError(
|
||||||
|
f"GitHub rate limit exceeded for {context}",
|
||||||
|
retry_after=retry_seconds,
|
||||||
|
status_code=status,
|
||||||
|
response_body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
if status >= 500:
|
||||||
|
raise PlatformError(
|
||||||
|
f"GitHub server error for {context}: {body}",
|
||||||
|
status_code=status,
|
||||||
|
response_body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise IntegrationError(
|
||||||
|
f"GitHub API error for {context}: {body}",
|
||||||
|
status_code=status,
|
||||||
|
response_body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_rate_limit(self, response: httpx.Response) -> None:
|
||||||
|
"""Log warning if rate limit is getting low.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: HTTP response with rate limit headers.
|
||||||
|
"""
|
||||||
|
remaining = response.headers.get("X-RateLimit-Remaining")
|
||||||
|
if remaining is not None:
|
||||||
|
remaining_int = int(remaining)
|
||||||
|
if remaining_int < 10:
|
||||||
|
limit = response.headers.get("X-RateLimit-Limit", "unknown")
|
||||||
|
logger.warning(
|
||||||
|
"GitHub rate limit low: %d/%s remaining",
|
||||||
|
remaining_int,
|
||||||
|
limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _request_with_retry(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
url: str,
|
||||||
|
context: str,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> httpx.Response:
|
||||||
|
"""Make a request with automatic retry on transient errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: HTTP method.
|
||||||
|
url: Request URL.
|
||||||
|
context: Context string for error messages.
|
||||||
|
**kwargs: Additional request arguments.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HTTP response.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
retry=retry_if_exception_type((PlatformError, httpx.TransportError)),
|
||||||
|
stop=stop_after_attempt(self.max_retries),
|
||||||
|
wait=wait_exponential(multiplier=1, min=1, max=10),
|
||||||
|
)
|
||||||
|
async def _do_request() -> httpx.Response:
|
||||||
|
response = await self._client.request(method, url, **kwargs)
|
||||||
|
self._check_rate_limit(response)
|
||||||
|
self._handle_response_errors(response, context)
|
||||||
|
return response
|
||||||
|
|
||||||
|
return await _do_request()
|
||||||
|
|
||||||
|
async def get_pr_diff(self, repository: str, pr_number: int) -> str:
|
||||||
|
"""Fetch the unified diff for a pull request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository: Repository name in owner/repo format.
|
||||||
|
pr_number: Pull request number.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Unified diff as a string.
|
||||||
|
"""
|
||||||
|
url = f"/repos/{repository}/pulls/{pr_number}"
|
||||||
|
response = await self._request_with_retry(
|
||||||
|
"GET",
|
||||||
|
url,
|
||||||
|
f"get diff for {repository}#{pr_number}",
|
||||||
|
headers={"Accept": "application/vnd.github.v3.diff"},
|
||||||
|
)
|
||||||
|
return response.text
|
||||||
|
|
||||||
|
async def post_comment(self, repository: str, pr_number: int, body: str) -> str:
|
||||||
|
"""Post a comment on a pull request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository: Repository name in owner/repo format.
|
||||||
|
pr_number: Pull request number.
|
||||||
|
body: Comment body in Markdown.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
URL to the posted comment.
|
||||||
|
"""
|
||||||
|
url = f"/repos/{repository}/issues/{pr_number}/comments"
|
||||||
|
response = await self._request_with_retry(
|
||||||
|
"POST",
|
||||||
|
url,
|
||||||
|
f"post comment on {repository}#{pr_number}",
|
||||||
|
json={"body": body},
|
||||||
|
)
|
||||||
|
data: dict[str, Any] = response.json()
|
||||||
|
return str(data.get("html_url", ""))
|
||||||
|
|
||||||
|
async def update_commit_status(
|
||||||
|
self,
|
||||||
|
repository: str,
|
||||||
|
sha: str,
|
||||||
|
status: CommitStatus,
|
||||||
|
description: str,
|
||||||
|
context: str,
|
||||||
|
target_url: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Update the commit status.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository: Repository name in owner/repo format.
|
||||||
|
sha: Commit SHA to update status for.
|
||||||
|
status: Status state.
|
||||||
|
description: Short description (max 140 chars).
|
||||||
|
context: Status context/name.
|
||||||
|
target_url: Optional URL for more details.
|
||||||
|
"""
|
||||||
|
url = f"/repos/{repository}/statuses/{sha}"
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"state": self.STATUS_MAP[status],
|
||||||
|
"description": description[:140], # GitHub limit
|
||||||
|
"context": context,
|
||||||
|
}
|
||||||
|
if target_url:
|
||||||
|
payload["target_url"] = target_url
|
||||||
|
|
||||||
|
await self._request_with_retry(
|
||||||
|
"POST",
|
||||||
|
url,
|
||||||
|
f"update status for {repository}@{sha[:8]}",
|
||||||
|
json=payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_pr_info(self, repository: str, pr_number: int) -> PullRequestInfo:
|
||||||
|
"""Get information about a pull request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository: Repository name in owner/repo format.
|
||||||
|
pr_number: Pull request number.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PullRequestInfo with PR metadata.
|
||||||
|
"""
|
||||||
|
url = f"/repos/{repository}/pulls/{pr_number}"
|
||||||
|
response = await self._request_with_retry(
|
||||||
|
"GET",
|
||||||
|
url,
|
||||||
|
f"get info for {repository}#{pr_number}",
|
||||||
|
)
|
||||||
|
data: dict[str, Any] = response.json()
|
||||||
|
|
||||||
|
return PullRequestInfo(
|
||||||
|
platform=Platform.GITHUB,
|
||||||
|
repository=repository,
|
||||||
|
pr_number=pr_number,
|
||||||
|
head_sha=data["head"]["sha"],
|
||||||
|
base_sha=data["base"]["sha"],
|
||||||
|
head_ref=data["head"]["ref"],
|
||||||
|
base_ref=data["base"]["ref"],
|
||||||
|
title=data["title"],
|
||||||
|
author=data.get("user", {}).get("login"),
|
||||||
|
url=data["html_url"],
|
||||||
|
is_draft=data.get("draft", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_comments(self, repository: str, pr_number: int) -> list[Comment]:
|
||||||
|
"""Get comments on a pull request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository: Repository name in owner/repo format.
|
||||||
|
pr_number: Pull request number.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of comments on the PR.
|
||||||
|
"""
|
||||||
|
url = f"/repos/{repository}/issues/{pr_number}/comments"
|
||||||
|
response = await self._request_with_retry(
|
||||||
|
"GET",
|
||||||
|
url,
|
||||||
|
f"get comments for {repository}#{pr_number}",
|
||||||
|
)
|
||||||
|
data: list[dict[str, Any]] = response.json()
|
||||||
|
|
||||||
|
comments: list[Comment] = []
|
||||||
|
for item in data:
|
||||||
|
comments.append(
|
||||||
|
Comment(
|
||||||
|
id=str(item["id"]),
|
||||||
|
body=item.get("body", ""),
|
||||||
|
author=item.get("user", {}).get("login", ""),
|
||||||
|
url=item.get("html_url", ""),
|
||||||
|
created_at=datetime.fromisoformat(item["created_at"].replace("Z", "+00:00")),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return comments
|
||||||
|
|
||||||
|
async def update_comment(
|
||||||
|
self, repository: str, pr_number: int, comment_id: str, body: str
|
||||||
|
) -> str:
|
||||||
|
"""Update an existing comment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository: Repository name in owner/repo format.
|
||||||
|
pr_number: Pull request number (unused for GitHub, but required by interface).
|
||||||
|
comment_id: ID of the comment to update.
|
||||||
|
body: New comment body in Markdown.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
URL to the updated comment.
|
||||||
|
"""
|
||||||
|
_ = pr_number # GitHub doesn't need pr_number for comment updates
|
||||||
|
url = f"/repos/{repository}/issues/comments/{comment_id}"
|
||||||
|
response = await self._request_with_retry(
|
||||||
|
"PATCH",
|
||||||
|
url,
|
||||||
|
f"update comment {comment_id} on {repository}",
|
||||||
|
json={"body": body},
|
||||||
|
)
|
||||||
|
data: dict[str, Any] = response.json()
|
||||||
|
return str(data.get("html_url", ""))
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
await self._client.aclose()
|
||||||
Reference in New Issue
Block a user