# Reconstructed from the new-file diff for src/arbiter/integrations/gitlab.py.
"""GitLab API client implementation."""

import logging
from datetime import datetime
from typing import Any, ClassVar
from urllib.parse import quote

import httpx
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from arbiter.integrations.base import (
    Comment,
    CommitStatus,
    Platform,
    PlatformClient,
    PullRequestInfo,
)
from arbiter.integrations.exceptions import (
    AuthenticationError,
    IntegrationError,
    NotFoundError,
    PlatformError,
    RateLimitError,
)

logger = logging.getLogger(__name__)


class GitLabClient(PlatformClient):
    """GitLab API client for MR operations."""

    # Status mapping from our enum to GitLab's expected values
    STATUS_MAP: ClassVar[dict[CommitStatus, str]] = {
        CommitStatus.PENDING: "pending",
        CommitStatus.SUCCESS: "success",
        CommitStatus.FAILURE: "failed",
        CommitStatus.ERROR: "failed",
    }

    # Page size requested from paginated list endpoints.
    PAGE_SIZE: ClassVar[int] = 100

    def __init__(
        self,
        token: str,
        base_url: str = "https://gitlab.com",
        timeout: int = 30,
        max_retries: int = 3,
    ) -> None:
        """Initialize GitLab client.

        Args:
            token: GitLab personal access token or project token.
            base_url: GitLab instance base URL.
            timeout: Request timeout in seconds.
            max_retries: Maximum retry attempts for transient errors.
        """
        self.base_url = base_url.rstrip("/")
        self.api_base = f"{self.base_url}/api/v4"
        self.timeout = timeout
        self.max_retries = max_retries

        self._client = httpx.AsyncClient(
            base_url=self.api_base,
            headers={
                "PRIVATE-TOKEN": token,
                "Content-Type": "application/json",
            },
            timeout=timeout,
        )

    @property
    def platform(self) -> Platform:
        """Return the platform this client talks to."""
        return Platform.GITLAB

    def _encode_project_id(self, repository: str) -> str:
        """Percent-encode a project path (including ``/``) for use in a URL."""
        return quote(repository, safe="")

    @staticmethod
    def _parse_int_header(value: Any) -> int | None:
        """Parse an integer header value, returning None if absent or malformed.

        Args:
            value: Raw header value, or None when the header is missing.

        Returns:
            The parsed integer, or None when the value is not an integer
            (for example a ``Retry-After`` given as an HTTP-date).
        """
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    def _note_url(self, repository: str, pr_number: int, note_id: Any) -> str:
        """Build the web URL of a merge request, anchored to a note if known.

        Args:
            repository: Project path in owner/repo format.
            pr_number: Merge request IID.
            note_id: Note identifier; a falsy value yields the bare MR URL.

        Returns:
            URL to the note, or to the merge request when no note id is given.
        """
        mr_url = f"{self.base_url}/{repository}/-/merge_requests/{pr_number}"
        return f"{mr_url}#note_{note_id}" if note_id else mr_url

    @staticmethod
    def _diff_header(diff_item: dict[str, Any]) -> tuple[str, str]:
        """Return the ``---`` / ``+++`` header lines for one diff entry.

        Args:
            diff_item: One entry of the merge request diffs response.

        Returns:
            Tuple of (old-file header line, new-file header line).
        """
        old_path = diff_item.get("old_path", "")
        new_path = diff_item.get("new_path", "")
        if diff_item.get("new_file"):
            return "--- /dev/null", f"+++ b/{new_path}"
        if diff_item.get("deleted_file"):
            return f"--- a/{old_path}", "+++ /dev/null"
        # Renames and plain modifications share the same header shape.
        return f"--- a/{old_path}", f"+++ b/{new_path}"

    def _handle_response_errors(self, response: httpx.Response, context: str) -> None:
        """Check response for errors and raise appropriate exceptions.

        Args:
            response: HTTP response to check.
            context: Context string for error messages.

        Raises:
            AuthenticationError: On HTTP 401 or 403.
            NotFoundError: On HTTP 404.
            RateLimitError: On HTTP 429.
            PlatformError: On HTTP 5xx.
            IntegrationError: On any other non-success status.
        """
        if response.is_success:
            return

        status = response.status_code
        body = response.text

        if status in (401, 403):
            raise AuthenticationError(
                f"GitLab authentication failed for {context}: {body}",
                status_code=status,
                response_body=body,
            )

        if status == 404:
            raise NotFoundError(
                f"GitLab resource not found for {context}: {body}",
                status_code=status,
                response_body=body,
            )

        if status == 429:
            # Retry-After may be an HTTP-date rather than delta-seconds; an
            # unparseable value must not mask the rate-limit error itself.
            retry_seconds = self._parse_int_header(response.headers.get("Retry-After"))
            raise RateLimitError(
                f"GitLab rate limit exceeded for {context}",
                retry_after=retry_seconds,
                status_code=status,
                response_body=body,
            )

        if status >= 500:
            raise PlatformError(
                f"GitLab server error for {context}: {body}",
                status_code=status,
                response_body=body,
            )

        raise IntegrationError(
            f"GitLab API error for {context}: {body}",
            status_code=status,
            response_body=body,
        )

    def _check_rate_limit(self, response: httpx.Response) -> None:
        """Log warning if rate limit is getting low.

        Args:
            response: HTTP response with rate limit headers.
        """
        # Purely advisory: a missing or malformed header is ignored so that
        # logging can never fail an otherwise successful request.
        remaining = self._parse_int_header(response.headers.get("RateLimit-Remaining"))
        if remaining is not None and remaining < 10:
            limit = response.headers.get("RateLimit-Limit", "unknown")
            logger.warning(
                "GitLab rate limit low: %d/%s remaining",
                remaining,
                limit,
            )

    async def _request_with_retry(
        self,
        method: str,
        url: str,
        context: str,
        **kwargs: Any,
    ) -> httpx.Response:
        """Make a request with automatic retry on transient errors.

        Retries on ``PlatformError`` (HTTP 5xx) and ``httpx.TransportError``
        with exponential backoff. Once attempts are exhausted the last
        underlying exception is re-raised as-is.

        Args:
            method: HTTP method.
            url: Request URL.
            context: Context string for error messages.
            **kwargs: Additional request arguments.

        Returns:
            HTTP response.
        """

        @retry(
            retry=retry_if_exception_type((PlatformError, httpx.TransportError)),
            stop=stop_after_attempt(self.max_retries),
            wait=wait_exponential(multiplier=1, min=1, max=10),
            # Surface the original exception instead of tenacity.RetryError so
            # callers can keep catching the integration exception hierarchy.
            reraise=True,
        )
        async def _do_request() -> httpx.Response:
            response = await self._client.request(method, url, **kwargs)
            self._check_rate_limit(response)
            self._handle_response_errors(response, context)
            return response

        return await _do_request()

    async def _get_all_pages(self, url: str, context: str) -> list[dict[str, Any]]:
        """Fetch every page of a paginated list endpoint.

        Follows the ``X-Next-Page`` response header until it is absent or empty.

        Args:
            url: Request URL of the list endpoint.
            context: Context string for error messages.

        Returns:
            Concatenated items from all pages, in server order.
        """
        items: list[dict[str, Any]] = []
        page = 1
        while True:
            response = await self._request_with_retry(
                "GET",
                url,
                context,
                params={"per_page": self.PAGE_SIZE, "page": page},
            )
            items.extend(response.json())

            next_page = response.headers.get("X-Next-Page")
            if not isinstance(next_page, str) or not next_page.isdigit():
                break
            following = int(next_page)
            if following <= page:
                # Guard against a server that does not advance the cursor.
                break
            page = following
        return items

    async def get_pr_diff(self, repository: str, pr_number: int) -> str:
        """Fetch the unified diff for a merge request.

        Args:
            repository: Project path in owner/repo format.
            pr_number: Merge request IID.

        Returns:
            Unified diff as a string.
        """
        project_id = self._encode_project_id(repository)
        url = f"/projects/{project_id}/merge_requests/{pr_number}/diffs"
        diffs = await self._get_all_pages(url, f"get diff for {repository}!{pr_number}")

        # Convert GitLab diff format to unified diff
        diff_parts: list[str] = []
        for diff_item in diffs:
            diff_parts.extend(self._diff_header(diff_item))
            diff_parts.append(diff_item.get("diff", ""))

        return "\n".join(diff_parts)

    async def post_comment(self, repository: str, pr_number: int, body: str) -> str:
        """Post a note (comment) on a merge request.

        Args:
            repository: Project path in owner/repo format.
            pr_number: Merge request IID.
            body: Note body in Markdown.

        Returns:
            URL to the posted note.
        """
        project_id = self._encode_project_id(repository)
        url = f"/projects/{project_id}/merge_requests/{pr_number}/notes"
        response = await self._request_with_retry(
            "POST",
            url,
            f"post comment on {repository}!{pr_number}",
            json={"body": body},
        )
        data: dict[str, Any] = response.json()

        return self._note_url(repository, pr_number, data.get("id"))

    async def update_commit_status(
        self,
        repository: str,
        sha: str,
        status: CommitStatus,
        description: str,
        context: str,
        target_url: str | None = None,
    ) -> None:
        """Update the commit status (pipeline status).

        Args:
            repository: Project path in owner/repo format.
            sha: Commit SHA to update status for.
            status: Status state.
            description: Short description.
            context: Status context/name.
            target_url: Optional URL for more details.
        """
        project_id = self._encode_project_id(repository)
        url = f"/projects/{project_id}/statuses/{sha}"
        payload: dict[str, Any] = {
            "state": self.STATUS_MAP[status],
            "description": description[:255],  # GitLab limit
            "name": context,
        }
        if target_url:
            payload["target_url"] = target_url

        await self._request_with_retry(
            "POST",
            url,
            f"update status for {repository}@{sha[:8]}",
            json=payload,
        )

    async def get_pr_info(self, repository: str, pr_number: int) -> PullRequestInfo:
        """Get information about a merge request.

        Args:
            repository: Project path in owner/repo format.
            pr_number: Merge request IID.

        Returns:
            PullRequestInfo with MR metadata.
        """
        project_id = self._encode_project_id(repository)
        url = f"/projects/{project_id}/merge_requests/{pr_number}"
        response = await self._request_with_retry(
            "GET",
            url,
            f"get info for {repository}!{pr_number}",
        )
        data: dict[str, Any] = response.json()

        # These keys may be present with a null value, so ``.get(key, {})``
        # is not enough to guarantee a dict.
        diff_refs: dict[str, Any] = data.get("diff_refs") or {}
        author: dict[str, Any] = data.get("author") or {}

        return PullRequestInfo(
            platform=Platform.GITLAB,
            repository=repository,
            pr_number=pr_number,
            head_sha=data.get("sha") or diff_refs.get("head_sha") or "",
            base_sha=diff_refs.get("base_sha") or "",
            head_ref=data["source_branch"],
            base_ref=data["target_branch"],
            title=data["title"],
            author=author.get("username"),
            url=data["web_url"],
            is_draft=bool(data.get("work_in_progress") or data.get("draft")),
        )

    async def get_comments(self, repository: str, pr_number: int) -> list[Comment]:
        """Get notes (comments) on a merge request.

        Args:
            repository: Project path in owner/repo format.
            pr_number: Merge request IID.

        Returns:
            List of notes on the MR.
        """
        project_id = self._encode_project_id(repository)
        url = f"/projects/{project_id}/merge_requests/{pr_number}/notes"
        notes = await self._get_all_pages(
            url, f"get comments for {repository}!{pr_number}"
        )

        return [
            Comment(
                id=str(item["id"]),
                body=item.get("body", ""),
                author=(item.get("author") or {}).get("username", ""),
                url=self._note_url(repository, pr_number, item["id"]),
                # Normalise a trailing "Z" so fromisoformat accepts it on
                # Python versions that do not parse the "Z" suffix.
                created_at=datetime.fromisoformat(
                    item["created_at"].replace("Z", "+00:00")
                ),
            )
            for item in notes
        ]

    async def update_comment(
        self, repository: str, pr_number: int, comment_id: str, body: str
    ) -> str:
        """Update an existing note on a merge request.

        Args:
            repository: Project path in owner/repo format.
            pr_number: Merge request IID.
            comment_id: ID of the note to update.
            body: New note body in Markdown.

        Returns:
            URL to the updated note.
        """
        project_id = self._encode_project_id(repository)
        url = f"/projects/{project_id}/merge_requests/{pr_number}/notes/{comment_id}"
        response = await self._request_with_retry(
            "PUT",
            url,
            f"update note {comment_id} on {repository}!{pr_number}",
            json={"body": body},
        )
        data: dict[str, Any] = response.json()

        # Fall back to the requested id when the response omits or nulls it.
        return self._note_url(repository, pr_number, data.get("id") or comment_id)

    async def close(self) -> None:
        """Close the underlying HTTP client and release its connections."""
        await self._client.aclose()