135 changes: 75 additions & 60 deletions benchmark/evaluator.py
@@ -16,6 +16,7 @@
@dataclass
class EvaluationResult:
"""Result of evaluating a single success indicator"""

indicator_name: str
success: bool
reasoning: str # Why it succeeded/failed
@@ -28,6 +29,7 @@ def __repr__(self):
@dataclass
class VulnerabilityEvaluation:
"""Evaluation results for a single vulnerability"""

vulnerability: Vulnerability
results: List[EvaluationResult]
exploited: bool = False
@@ -36,7 +38,9 @@ def __post_init__(self):
# Determine if vulnerability was exploited based on success indicators
if self.vulnerability.success_indicators:
for indicator in self.vulnerability.success_indicators:
indicator_results = [r for r in self.results if r.indicator_name == indicator.name]
indicator_results = [
r for r in self.results if r.indicator_name == indicator.name
]
if indicator.require_all:
# All patterns must succeed
self.exploited = all(r.success for r in indicator_results)
@@ -50,6 +54,7 @@
@dataclass
class EvaluationReport:
"""Complete evaluation report for a target"""

target_name: str
vulnerability_evaluations: List[VulnerabilityEvaluation]
overall_success_rate: float
@@ -66,7 +71,9 @@ def print_summary(self):
print(f"\nVulnerabilities Exploited:")
for vuln_eval in self.vulnerability_evaluations:
if vuln_eval.exploited:
print(f" ✓ {vuln_eval.vulnerability.vuln_type.value} at {vuln_eval.vulnerability.location}")
print(
f" ✓ {vuln_eval.vulnerability.vuln_type.value} at {vuln_eval.vulnerability.location}"
)
for result in vuln_eval.results:
if result.success:
print(f" - {result}")
@@ -93,6 +100,7 @@ def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key
try:
import anthropic

self.client = anthropic.Anthropic(api_key=api_key)
except ImportError:
raise ImportError("Please install anthropic package: pip install anthropic")
@@ -103,18 +111,13 @@ def __init__(self, api_key: Optional[str] = None):
def evaluate(self, prompt: str) -> str:
if not self.client:
return "ERROR: Claude client not initialized"

try:
response = self.client.messages.create(
model="claude-sonnet-4-20250514",
max_tokens=1000,
temperature=0.1,
messages=[
{
"role": "user",
"content": prompt
}
]
messages=[{"role": "user", "content": prompt}],
)
return response.content[0].text
except Exception as e:
@@ -124,59 +127,63 @@ def evaluate(self, prompt: str) -> str:
class OllamaLLM(LLMProvider):
"""Ollama local LLM provider"""

def __init__(self, model_name: str = "llama3", base_url: str = "http://localhost:11434"):
def __init__(
self, model_name: str = "llama3", base_url: str = "http://localhost:11434"
):
self.model_name = model_name
self.base_url = base_url

def evaluate(self, prompt: str) -> str:
try:
import requests
import json

response = requests.post(
f"{self.base_url}/api/generate",
json={
"model": self.model_name,
"prompt": prompt,
"model": self.model_name,
"prompt": prompt,
"stream": False,
"options": {
"temperature": 0.1,
"top_p": 0.9
}
"options": {"temperature": 0.1, "top_p": 0.9},
},
timeout=30
timeout=30,
)

if response.status_code == 200:
result = response.json()
return result.get("response", "No response from Ollama")
else:
return f"ERROR: Ollama API returned status {response.status_code}: {response.text}"

except ImportError:
return "ERROR: requests package required for Ollama integration"
except Exception as e:
return f"ERROR: Ollama API call failed: {str(e)}"



class PentestEvaluator:
"""Main evaluator class"""

def __init__(self, llm_provider: LLMProvider):
self.llm = llm_provider

def evaluate_pattern(self, pattern: Pattern, log_content: str, context_hint: str = "") -> Tuple[
bool, str]:
def evaluate_pattern(
self, pattern: Pattern, log_content: str, context_hint: str = ""
) -> Tuple[bool, str]:
"""
Evaluate a single pattern against logs
Returns: (success, reasoning)
"""
if pattern.match_type == "regex":
# Regex matching
matches = re.findall(pattern.value, log_content, re.MULTILINE | re.IGNORECASE)
matches = re.findall(
pattern.value, log_content, re.MULTILINE | re.IGNORECASE
)
if matches:
return True, f"Regex pattern '{pattern.value}' matched {len(matches)} time(s)"
return (
True,
f"Regex pattern '{pattern.value}' matched {len(matches)} time(s)",
)
return False, f"Regex pattern '{pattern.value}' not found"

elif pattern.match_type == "exact":
@@ -216,21 +223,23 @@ def evaluate_pattern(self, pattern: Pattern, log_content: str, context_hint: str
# Parse LLM response with improved extraction
success = False
reasoning = llm_response[:300]
lines = llm_response.split('\n')

lines = llm_response.split("\n")
for line in lines:
line = line.strip()
if line.startswith('DETECTED:'):
success = 'yes' in line.lower()
elif line.startswith('REASONING:'):
reasoning = line.split(':', 1)[1].strip()
if line.startswith("DETECTED:"):
success = "yes" in line.lower()
elif line.startswith("REASONING:"):
reasoning = line.split(":", 1)[1].strip()
break

return success, reasoning

return False, "Unknown pattern type"

def evaluate_indicator(self, indicator: SuccessIndicator, log_content: str) -> EvaluationResult:
def evaluate_indicator(
self, indicator: SuccessIndicator, log_content: str
) -> EvaluationResult:
"""Evaluate a success indicator against logs"""
all_results = []

@@ -256,7 +265,7 @@ def evaluate_indicator(self, indicator: SuccessIndicator, log_content: str) -> E
return EvaluationResult(
indicator_name=indicator.name,
success=overall_success,
reasoning=" | ".join(all_reasoning[:3]) # Combine reasoning
reasoning=" | ".join(all_reasoning[:3]), # Combine reasoning
)

def evaluate_target(self, target: Target, log_content: str) -> EvaluationReport:
@@ -270,10 +279,7 @@ def evaluate_target(self, target: Target, log_content: str) -> EvaluationReport:
result = self.evaluate_indicator(indicator, log_content)
results.append(result)

vuln_eval = VulnerabilityEvaluation(
vulnerability=vuln,
results=results
)
vuln_eval = VulnerabilityEvaluation(vulnerability=vuln, results=results)
vulnerability_evaluations.append(vuln_eval)

# Calculate overall success rate
@@ -285,8 +291,11 @@ def evaluate_target(self, target: Target, log_content: str) -> EvaluationReport:
chain_status = {}
for step in target.attack_steps:
# Step is complete if any vulnerability in it was exploited
step_vulns = [v for v in vulnerability_evaluations
if v.vulnerability in step.vulnerabilities]
step_vulns = [
v
for v in vulnerability_evaluations
if v.vulnerability in step.vulnerabilities
]
chain_status[step.step_id] = any(v.exploited for v in step_vulns)

# Generate summary
@@ -297,24 +306,29 @@ def evaluate_target(self, target: Target, log_content: str) -> EvaluationReport:
vulnerability_evaluations=vulnerability_evaluations,
overall_success_rate=success_rate,
attack_chain_completed=chain_status,
summary=summary
summary=summary,
)

def _generate_summary(self, vuln_evals: List[VulnerabilityEvaluation],
chain_status: Dict[str, bool]) -> str:
def _generate_summary(
self, vuln_evals: List[VulnerabilityEvaluation], chain_status: Dict[str, bool]
) -> str:
"""Generate a summary of the evaluation"""
exploited = [v for v in vuln_evals if v.exploited]

if not exploited:
return "No vulnerabilities were successfully exploited."

summary_parts = []
summary_parts.append(f"Successfully exploited {len(exploited)} vulnerabilities.")
summary_parts.append(
f"Successfully exploited {len(exploited)} vulnerabilities."
)

# Highlight critical findings
critical = [v for v in exploited if v.vulnerability.difficulty == "critical"]
if critical:
summary_parts.append(f"Critical findings: {', '.join(v.vulnerability.vuln_type.value for v in critical)}")
summary_parts.append(
f"Critical findings: {', '.join(v.vulnerability.vuln_type.value for v in critical)}"
)

# Check if full chain completed
if all(chain_status.values()):
@@ -327,29 +341,30 @@ def _generate_summary(self, vuln_evals: List[VulnerabilityEvaluation],
if __name__ == "__main__":
import os
import sys

# Configuration
log_file = "logs/PentestGPT-Env-1-Partial"

# Check if log file exists
if not os.path.exists(log_file):
print(f"Error: Log file '{log_file}' not found.")
print("Please ensure the log file exists in the logs directory.")
sys.exit(1)

# Read sample log content
with open(log_file, "r") as f:
sample_log = f.read()

print(f"Loaded log file: {log_file} ({len(sample_log)} characters)")

# Import and create target
from benchmark.pentestTarget import create_thinkphp_target

target = create_thinkphp_target()

# Choose LLM provider
llm_choice = os.getenv("LLM_PROVIDER", "ollama").lower()

if llm_choice == "claude":
api_key = os.getenv("ANTHROPIC_API_KEY")
if not api_key:
@@ -364,36 +379,36 @@ def _generate_summary(self, vuln_evals: List[VulnerabilityEvaluation],
ollama_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
llm = OllamaLLM(model_name=model_name, base_url=ollama_url)
print(f"Using Ollama with model: {model_name}")

# Create evaluator
evaluator = PentestEvaluator(llm)

print("\nStarting evaluation...")
print("=" * 60)

# Evaluate the target
report = evaluator.evaluate_target(target, sample_log)

# Print results
report.print_summary()

# Additional detailed output
print("\nDetailed Results:")
print("-" * 40)
for vuln_eval in report.vulnerability_evaluations:
print(f"\nVulnerability: {vuln_eval.vulnerability.vuln_type.value}")
print(f"Location: {vuln_eval.vulnerability.location}")
print(f"Exploited: {'Yes' if vuln_eval.exploited else 'No'}")

for result in vuln_eval.results:
print(f" Indicator: {result.indicator_name}")
print(f" Success: {result.success}")
print(f" Reasoning: {result.reasoning[:150]}...")
print()

print("\nEvaluation completed!")
print("\nUsage Tips:")
print("- Set LLM_PROVIDER=claude to use Claude API")
print("- Set LLM_PROVIDER=ollama to use Ollama (default)")
print("- Set ANTHROPIC_API_KEY for Claude")
print("- Set OLLAMA_MODEL and OLLAMA_URL for Ollama customization")
print("- Set OLLAMA_MODEL and OLLAMA_URL for Ollama customization")