just why not
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
All checks were successful
AI Codebase Quality Review / ai-codebase-review (push) Successful in 39s
This commit is contained in:
420
tools/ai-review/clients/providers/azure_provider.py
Normal file
420
tools/ai-review/clients/providers/azure_provider.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""Azure OpenAI Provider
|
||||
|
||||
Integration with Azure OpenAI Service for enterprise deployments.
|
||||
Supports custom deployments, regional endpoints, and Azure AD auth.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
import requests
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from clients.llm_client import BaseLLMProvider, LLMResponse, ToolCall
|
||||
|
||||
|
||||
class AzureOpenAIProvider(BaseLLMProvider):
|
||||
"""Azure OpenAI Service provider.
|
||||
|
||||
Provides integration with Azure-hosted OpenAI models for
|
||||
enterprise customers with Azure deployments.
|
||||
|
||||
Supports:
|
||||
- GPT-4, GPT-4 Turbo, GPT-4o
|
||||
- GPT-3.5 Turbo
|
||||
- Custom fine-tuned models
|
||||
|
||||
Environment Variables:
|
||||
- AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL
|
||||
- AZURE_OPENAI_API_KEY: API key for authentication
|
||||
- AZURE_OPENAI_DEPLOYMENT: Default deployment name
|
||||
- AZURE_OPENAI_API_VERSION: API version (default: 2024-02-15-preview)
|
||||
"""
|
||||
|
||||
DEFAULT_API_VERSION = "2024-02-15-preview"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str | None = None,
|
||||
api_key: str | None = None,
|
||||
deployment: str | None = None,
|
||||
api_version: str | None = None,
|
||||
temperature: float = 0,
|
||||
max_tokens: int = 4096,
|
||||
):
|
||||
"""Initialize the Azure OpenAI provider.
|
||||
|
||||
Args:
|
||||
endpoint: Azure OpenAI endpoint URL.
|
||||
Defaults to AZURE_OPENAI_ENDPOINT env var.
|
||||
api_key: API key for authentication.
|
||||
Defaults to AZURE_OPENAI_API_KEY env var.
|
||||
deployment: Deployment name to use.
|
||||
Defaults to AZURE_OPENAI_DEPLOYMENT env var.
|
||||
api_version: API version string.
|
||||
Defaults to AZURE_OPENAI_API_VERSION env var or latest.
|
||||
temperature: Sampling temperature (0-2).
|
||||
max_tokens: Maximum tokens in response.
|
||||
"""
|
||||
self.endpoint = (
|
||||
endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT", "")
|
||||
).rstrip("/")
|
||||
self.api_key = api_key or os.environ.get("AZURE_OPENAI_API_KEY", "")
|
||||
self.deployment = deployment or os.environ.get("AZURE_OPENAI_DEPLOYMENT", "")
|
||||
self.api_version = api_version or os.environ.get(
|
||||
"AZURE_OPENAI_API_VERSION", self.DEFAULT_API_VERSION
|
||||
)
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
def _get_api_url(self, deployment: str | None = None) -> str:
|
||||
"""Build the API URL for a given deployment.
|
||||
|
||||
Args:
|
||||
deployment: Deployment name. Uses default if not specified.
|
||||
|
||||
Returns:
|
||||
Full API URL for chat completions.
|
||||
"""
|
||||
deploy = deployment or self.deployment
|
||||
return (
|
||||
f"{self.endpoint}/openai/deployments/{deploy}"
|
||||
f"/chat/completions?api-version={self.api_version}"
|
||||
)
|
||||
|
||||
def call(self, prompt: str, **kwargs) -> LLMResponse:
|
||||
"""Make a call to the Azure OpenAI API.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send.
|
||||
**kwargs: Additional options (deployment, temperature, max_tokens).
|
||||
|
||||
Returns:
|
||||
LLMResponse with the generated content.
|
||||
|
||||
Raises:
|
||||
ValueError: If required configuration is missing.
|
||||
requests.HTTPError: If the API request fails.
|
||||
"""
|
||||
if not self.endpoint:
|
||||
raise ValueError("Azure OpenAI endpoint is required")
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure OpenAI API key is required")
|
||||
if not self.deployment and not kwargs.get("deployment"):
|
||||
raise ValueError("Azure OpenAI deployment name is required")
|
||||
|
||||
deployment = kwargs.get("deployment", self.deployment)
|
||||
|
||||
response = requests.post(
|
||||
self._get_api_url(deployment),
|
||||
headers={
|
||||
"api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
|
||||
return LLMResponse(
|
||||
content=message.get("content", ""),
|
||||
model=data.get("model", deployment),
|
||||
provider="azure",
|
||||
tokens_used=data.get("usage", {}).get("total_tokens", 0),
|
||||
finish_reason=choice.get("finish_reason"),
|
||||
)
|
||||
|
||||
def call_with_tools(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""Make a call to the Azure OpenAI API with tool support.
|
||||
|
||||
Azure OpenAI uses the same format as OpenAI for tools.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts with 'role' and 'content'.
|
||||
tools: List of tool definitions in OpenAI format.
|
||||
**kwargs: Additional options.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool_calls.
|
||||
"""
|
||||
if not self.endpoint:
|
||||
raise ValueError("Azure OpenAI endpoint is required")
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure OpenAI API key is required")
|
||||
if not self.deployment and not kwargs.get("deployment"):
|
||||
raise ValueError("Azure OpenAI deployment name is required")
|
||||
|
||||
deployment = kwargs.get("deployment", self.deployment)
|
||||
|
||||
request_body = {
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
}
|
||||
|
||||
if tools:
|
||||
request_body["tools"] = tools
|
||||
request_body["tool_choice"] = kwargs.get("tool_choice", "auto")
|
||||
|
||||
response = requests.post(
|
||||
self._get_api_url(deployment),
|
||||
headers={
|
||||
"api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=request_body,
|
||||
timeout=120,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
|
||||
# Parse tool calls if present
|
||||
tool_calls = None
|
||||
if message.get("tool_calls"):
|
||||
tool_calls = []
|
||||
for tc in message["tool_calls"]:
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments", "{}")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=tc.get("id", ""),
|
||||
name=func.get("name", ""),
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
content=message.get("content", "") or "",
|
||||
model=data.get("model", deployment),
|
||||
provider="azure",
|
||||
tokens_used=data.get("usage", {}).get("total_tokens", 0),
|
||||
finish_reason=choice.get("finish_reason"),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
|
||||
class AzureOpenAIWithAADProvider(AzureOpenAIProvider):
|
||||
"""Azure OpenAI provider with Azure Active Directory authentication.
|
||||
|
||||
Uses Azure AD tokens instead of API keys for authentication.
|
||||
Requires azure-identity package for token acquisition.
|
||||
|
||||
Environment Variables:
|
||||
- AZURE_OPENAI_ENDPOINT: Azure OpenAI endpoint URL
|
||||
- AZURE_OPENAI_DEPLOYMENT: Default deployment name
|
||||
- AZURE_TENANT_ID: Azure AD tenant ID (optional)
|
||||
- AZURE_CLIENT_ID: Azure AD client ID (optional)
|
||||
- AZURE_CLIENT_SECRET: Azure AD client secret (optional)
|
||||
"""
|
||||
|
||||
SCOPE = "https://cognitiveservices.azure.com/.default"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
endpoint: str | None = None,
|
||||
deployment: str | None = None,
|
||||
api_version: str | None = None,
|
||||
temperature: float = 0,
|
||||
max_tokens: int = 4096,
|
||||
credential=None,
|
||||
):
|
||||
"""Initialize the Azure OpenAI AAD provider.
|
||||
|
||||
Args:
|
||||
endpoint: Azure OpenAI endpoint URL.
|
||||
deployment: Deployment name to use.
|
||||
api_version: API version string.
|
||||
temperature: Sampling temperature (0-2).
|
||||
max_tokens: Maximum tokens in response.
|
||||
credential: Azure credential object. If not provided,
|
||||
uses DefaultAzureCredential.
|
||||
"""
|
||||
super().__init__(
|
||||
endpoint=endpoint,
|
||||
api_key="", # Not used with AAD
|
||||
deployment=deployment,
|
||||
api_version=api_version,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
self._credential = credential
|
||||
self._token = None
|
||||
self._token_expires_at = 0
|
||||
|
||||
def _get_token(self) -> str:
|
||||
"""Get an Azure AD token for authentication.
|
||||
|
||||
Returns:
|
||||
Bearer token string.
|
||||
|
||||
Raises:
|
||||
ImportError: If azure-identity is not installed.
|
||||
"""
|
||||
import time
|
||||
|
||||
# Return cached token if still valid (with 5 min buffer)
|
||||
if self._token and self._token_expires_at > time.time() + 300:
|
||||
return self._token
|
||||
|
||||
try:
|
||||
from azure.identity import DefaultAzureCredential
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"azure-identity package is required for AAD authentication. "
|
||||
"Install with: pip install azure-identity"
|
||||
)
|
||||
|
||||
if self._credential is None:
|
||||
self._credential = DefaultAzureCredential()
|
||||
|
||||
token = self._credential.get_token(self.SCOPE)
|
||||
self._token = token.token
|
||||
self._token_expires_at = token.expires_on
|
||||
|
||||
return self._token
|
||||
|
||||
def call(self, prompt: str, **kwargs) -> LLMResponse:
|
||||
"""Make a call to the Azure OpenAI API using AAD auth.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to send.
|
||||
**kwargs: Additional options.
|
||||
|
||||
Returns:
|
||||
LLMResponse with the generated content.
|
||||
"""
|
||||
if not self.endpoint:
|
||||
raise ValueError("Azure OpenAI endpoint is required")
|
||||
if not self.deployment and not kwargs.get("deployment"):
|
||||
raise ValueError("Azure OpenAI deployment name is required")
|
||||
|
||||
deployment = kwargs.get("deployment", self.deployment)
|
||||
token = self._get_token()
|
||||
|
||||
response = requests.post(
|
||||
self._get_api_url(deployment),
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
|
||||
return LLMResponse(
|
||||
content=message.get("content", ""),
|
||||
model=data.get("model", deployment),
|
||||
provider="azure",
|
||||
tokens_used=data.get("usage", {}).get("total_tokens", 0),
|
||||
finish_reason=choice.get("finish_reason"),
|
||||
)
|
||||
|
||||
def call_with_tools(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""Make a call to the Azure OpenAI API with tool support using AAD auth.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts.
|
||||
tools: List of tool definitions.
|
||||
**kwargs: Additional options.
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and/or tool_calls.
|
||||
"""
|
||||
if not self.endpoint:
|
||||
raise ValueError("Azure OpenAI endpoint is required")
|
||||
if not self.deployment and not kwargs.get("deployment"):
|
||||
raise ValueError("Azure OpenAI deployment name is required")
|
||||
|
||||
deployment = kwargs.get("deployment", self.deployment)
|
||||
token = self._get_token()
|
||||
|
||||
request_body = {
|
||||
"messages": messages,
|
||||
"max_tokens": kwargs.get("max_tokens", self.max_tokens),
|
||||
"temperature": kwargs.get("temperature", self.temperature),
|
||||
}
|
||||
|
||||
if tools:
|
||||
request_body["tools"] = tools
|
||||
request_body["tool_choice"] = kwargs.get("tool_choice", "auto")
|
||||
|
||||
response = requests.post(
|
||||
self._get_api_url(deployment),
|
||||
headers={
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json=request_body,
|
||||
timeout=120,
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
choice = data.get("choices", [{}])[0]
|
||||
message = choice.get("message", {})
|
||||
|
||||
# Parse tool calls if present
|
||||
tool_calls = None
|
||||
if message.get("tool_calls"):
|
||||
tool_calls = []
|
||||
for tc in message["tool_calls"]:
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments", "{}")
|
||||
if isinstance(args, str):
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=tc.get("id", ""),
|
||||
name=func.get("name", ""),
|
||||
arguments=args,
|
||||
)
|
||||
)
|
||||
|
||||
return LLMResponse(
|
||||
content=message.get("content", "") or "",
|
||||
model=data.get("model", deployment),
|
||||
provider="azure",
|
||||
tokens_used=data.get("usage", {}).get("total_tokens", 0),
|
||||
finish_reason=choice.get("finish_reason"),
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
Reference in New Issue
Block a user