|
|
""" |
|
|
Streamlit app for testing machine-generated code detection model with explainability. |
|
|
|
|
|
This app allows users to: |
|
|
1. Input code snippets |
|
|
2. Get predictions on whether the code is human-written or machine-generated |
|
|
3. View feature importance and explanations for the prediction |
|
|
""" |
|
|
|
|
|
import streamlit as st |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import json |
|
|
import joblib |
|
|
import os |
|
|
from typing import Dict, List, Any |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
import plotly.express as px |
|
|
import plotly.graph_objects as go |
|
|
from plotly.subplots import make_subplots |
|
|
|
|
|
|
|
|
from code_analytics import extract_all_code_analytics, get_analytics_feature_names |
|
|
from entropy_weighted_perplexity import EntropyWeightedPerplexity |
|
|
|
|
|
|
|
|
# SHAP is an optional dependency: when it cannot be imported the flag is
# False and the rest of the app skips SHAP-based explanations.
try:
    import shap
except ImportError:
    SHAP_AVAILABLE = False
else:
    SHAP_AVAILABLE = True
|
|
|
|
|
# Page-level Streamlit settings: browser tab title/icon, wide layout, and the
# sidebar (which hosts the model loader) open on first load.
st.set_page_config(
    layout="wide",
    initial_sidebar_state="expanded",
    page_title="AI Code Detection Tool",
    page_icon="π€",
)
|
|
|
|
|
class ModelLoader:
    """Load the trained classifier and its JSON metadata from disk.

    Expects two files inside ``model_dir``:

    * ``trained_model.pkl`` -- the estimator, read with ``joblib.load``.
    * ``model_metadata.json`` -- metadata such as ``model_type``,
      ``feature_names``, ``metrics``, ``config`` and ``label_mapping``.

    Attributes:
        model_dir: Directory holding the two artifacts.
        model: The loaded estimator, or ``None`` until ``load_model`` succeeds.
        metadata: The parsed metadata dict, or ``None`` until ``load_model``
            succeeds.
    """

    def __init__(self, model_dir: str = "results"):
        self.model_dir = model_dir
        self.model = None
        self.metadata = None

    def load_model(self) -> bool:
        """Load the model and metadata from ``model_dir``.

        Both artifacts are read into locals first and only stored on the
        instance once both reads succeeded, so a failure never leaves a loaded
        model paired with missing metadata.

        Returns:
            True on success; False if either file is missing or loading raised
            (in which case the error is shown in the Streamlit UI).
        """
        try:
            model_path = os.path.join(self.model_dir, "trained_model.pkl")
            metadata_path = os.path.join(self.model_dir, "model_metadata.json")

            if not (os.path.exists(model_path) and os.path.exists(metadata_path)):
                return False

            # NOTE: joblib.load unpickles arbitrary objects -- only point
            # model_dir at artifacts you produced yourself.
            model = joblib.load(model_path)

            # JSON is UTF-8 by specification; don't rely on the platform locale.
            with open(metadata_path, "r", encoding="utf-8") as f:
                metadata = json.load(f)
        except Exception as e:
            st.error(f"Error loading model: {e}")
            return False

        self.model = model
        self.metadata = metadata
        return True

    def get_model_info(self) -> Dict[str, Any]:
        """Summarise the loaded model for display in the sidebar.

        Returns:
            An empty dict when no metadata is loaded; otherwise display-ready
            values for model type, feature count, F1 / accuracy (formatted to
            four decimals, 0 when absent) and the feature families used.
        """
        if self.metadata is None:
            return {}

        metrics = self.metadata.get("metrics", {})
        features_cfg = self.metadata.get("config", {}).get("features", {})
        uses_llm = features_cfg.get("use_llm_features", False)

        return {
            "Model Type": self.metadata.get("model_type", "Unknown"),
            "Number of Features": len(self.metadata.get("feature_names", [])),
            "F1 Score": f"{metrics.get('f1_macro', 0):.4f}",
            "Accuracy": f"{metrics.get('accuracy', 0):.4f}",
            "Features Used": "Code Analytics + LLM" if uses_llm else "Code Analytics",
        }
|
|
|
|
|
class CodeAnalyzer:
    """Extract features from code snippets, classify them and explain the result.

    Feature extraction follows the ``config`` section of the loaded model
    metadata so that the vector layout matches what the model was trained on.
    """

    # Keys read from EntropyWeightedPerplexity.calculate_entropy_weighted_score,
    # in the exact order they are appended to the feature vector.
    LLM_FEATURE_KEYS = (
        "entropy_weighted_score",
        "mean_entropy",
        "mean_windowed_entropy",
        "mean_cross_entropy",
        "sequence_length",
        "entropy_cross_entropy_ratio",
        "windowed_raw_entropy_ratio",
    )

    # Number of synthetic background rows for the KernelExplainer fallback.
    SHAP_BACKGROUND_SIZE = 10

    def __init__(self, model_loader: "ModelLoader"):
        self.model_loader = model_loader

    def _compute_features(self, code: str) -> np.ndarray:
        """Build the feature vector for *code*, raising on any failure.

        Layout: LLM features first (only when ``use_llm_features`` is enabled),
        then code-analytics features in ``get_analytics_feature_names()`` order
        (when ``use_code_analytics`` is enabled, which is the default).
        Analytics features missing from the extractor's output become 0.0.
        """
        metadata = self.model_loader.metadata
        if metadata is None:
            raise ValueError("Model metadata not loaded")

        config = metadata.get("config", {})
        features_config = config.get("features", {})
        all_features: List[float] = []

        if features_config.get("use_llm_features", False):
            model_cfg = config["model"]
            # NOTE(review): a new EntropyWeightedPerplexity is built on every
            # call; if it loads a language model this is expensive and worth
            # caching (e.g. with st.cache_resource).
            ewp_calculator = EntropyWeightedPerplexity(
                model_name=model_cfg["name"],
                entropy_window_size=model_cfg["entropy_window_size"],
                entropy_weight=model_cfg["entropy_weight"],
                perplexity_weight=model_cfg["perplexity_weight"],
            )
            llm_features = ewp_calculator.calculate_entropy_weighted_score(code)
            all_features.extend(llm_features[key] for key in self.LLM_FEATURE_KEYS)

        if features_config.get("use_code_analytics", True):
            analytics_features = extract_all_code_analytics(code)
            all_features.extend(
                analytics_features.get(name, 0.0)
                for name in get_analytics_feature_names()
            )

        return np.array(all_features)

    def extract_features(self, code: str) -> np.ndarray:
        """Extract features from *code* using the same pipeline as training.

        On failure the error is shown in the UI and a zero vector sized to the
        metadata's ``feature_names`` is returned (empty if metadata is missing).
        """
        try:
            return self._compute_features(code)
        except Exception as e:
            st.error(f"Feature extraction failed: {e}")
            # Missing metadata is itself one way to get here, so guard it
            # instead of dereferencing None inside the handler.
            metadata = self.model_loader.metadata or {}
            return np.zeros(len(metadata.get("feature_names", [])))

    def predict(self, code: str) -> Dict[str, Any]:
        """Classify *code* and collect the data the result views need.

        Returns:
            On success: ``prediction`` (raw class from ``model.predict``),
            ``probability`` (row from ``predict_proba``), ``features`` (1-D
            feature vector), ``feature_importance`` (name -> importance) and
            ``label`` (looked up in metadata ``label_mapping`` by ``str(prediction)``).
            On failure: a dict with a single ``error`` message. A failed feature
            extraction is reported here rather than predicting on zeros.
        """
        model = self.model_loader.model
        if model is None:
            return {"error": "Model not loaded"}

        try:
            features = self._compute_features(code).reshape(1, -1)
        except Exception as e:
            return {"error": f"Feature extraction failed: {e}"}

        try:
            prediction = model.predict(features)[0]
            probability = model.predict_proba(features)[0]
            feature_importance = self.get_feature_importance(features)
            return {
                "prediction": prediction,
                "probability": probability,
                "features": features[0],
                "feature_importance": feature_importance,
                "label": self.model_loader.metadata["label_mapping"][str(prediction)],
            }
        except Exception as e:
            return {"error": f"Prediction failed: {e}"}

    def get_feature_importance(self, features: np.ndarray) -> Dict[str, float]:
        """Return the model's global feature importances keyed by feature name.

        Sources, tried in order: ``feature_importances_``; absolute values of
        the first row of ``coef_``; the mean ``feature_importances_`` across
        ``estimators_``; otherwise uniform weights. The values do not depend on
        the contents of *features*, only on its width (it is the 2-D ``(1, n)``
        batch built by ``predict``).

        Returns an empty dict (after a UI warning) if anything fails.
        """
        model = self.model_loader.model
        try:
            n_features = len(features[0])

            if hasattr(model, "feature_importances_"):
                importances = model.feature_importances_
            elif hasattr(model, "coef_"):
                # atleast_2d also covers estimators exposing a 1-D coef_.
                importances = np.abs(np.atleast_2d(model.coef_)[0])
            else:
                per_estimator = [
                    est.feature_importances_
                    for est in getattr(model, "estimators_", [])
                    if hasattr(est, "feature_importances_")
                ]
                if per_estimator:
                    importances = np.mean(per_estimator, axis=0)
                else:
                    importances = np.ones(n_features) / n_features

            feature_names = self.model_loader.metadata.get(
                "feature_names", [f"Feature_{i}" for i in range(n_features)]
            )
            return dict(zip(feature_names, importances))
        except Exception as e:
            st.warning(f"Could not get feature importance: {e}")
            return {}

    @staticmethod
    def _positive_class_values(values: Any) -> np.ndarray:
        """Normalise SHAP output to a ``(n_samples, n_features)`` array for class 1.

        Older SHAP releases return a list with one array per class; newer
        releases return one ``(n_samples, n_features, n_classes)`` array.
        Single-output explanations are already ``(n_samples, n_features)``.
        """
        if isinstance(values, list):
            values = values[1]
        values = np.asarray(values)
        if values.ndim == 3:
            values = values[..., 1]
        return values

    @staticmethod
    def _positive_class_base_value(expected: Any) -> float:
        """Return the explainer's expected value for class 1, or its only value."""
        flat = np.ravel(np.asarray(expected, dtype=float))
        return float(flat[1] if flat.size > 1 else flat[0])

    def get_shap_explanation(self, code: str) -> Dict[str, Any]:
        """Compute SHAP attributions for *code* towards class index 1.

        Returns:
            ``shap_values`` (1-D, one per feature), ``feature_names``,
            ``base_value`` and ``feature_values``; or a dict with an ``error``
            message if SHAP is unavailable or anything fails.
        """
        if not SHAP_AVAILABLE:
            return {"error": "SHAP not available"}

        try:
            model = self.model_loader.model
            features = self._compute_features(code).reshape(1, -1)

            if hasattr(model, "feature_importances_"):
                explainer = shap.TreeExplainer(model)
            else:
                # No training data is available here, so KernelExplainer gets a
                # small synthetic background drawn around the sample itself.
                # A fixed seed keeps explanations identical across Streamlit
                # reruns for the same input.
                rng = np.random.default_rng(0)
                background_features = rng.normal(
                    features.mean(),
                    features.std(),
                    (self.SHAP_BACKGROUND_SIZE, features.shape[1]),
                )
                explainer = shap.KernelExplainer(model.predict_proba, background_features)

            shap_values = self._positive_class_values(explainer.shap_values(features))

            feature_names = self.model_loader.metadata.get(
                "feature_names", [f"Feature_{i}" for i in range(features.shape[1])]
            )
            base_value = (
                self._positive_class_base_value(explainer.expected_value)
                if hasattr(explainer, "expected_value")
                else 0.5
            )

            return {
                "shap_values": shap_values[0],
                "feature_names": feature_names,
                "base_value": base_value,
                "feature_values": features[0],
            }
        except Exception as e:
            return {"error": f"SHAP explanation failed: {e}"}
|
|
|
|
|
def create_shap_waterfall_plot(shap_explanation: Dict[str, Any], top_n: int = 15):
    """Plot the strongest per-feature SHAP contributions as horizontal bars.

    Despite the name this is a signed bar chart, not a cumulative waterfall.

    Args:
        shap_explanation: Dict with ``shap_values`` and ``feature_names``
            (as produced by the SHAP explanation step); an ``error`` key
            short-circuits to ``None``.
        top_n: Maximum number of features to show, selected by |SHAP value|.

    Returns:
        A plotly ``Figure`` with the most influential feature at the top
        (positive contributions green, the rest red), or ``None`` when the
        explanation holds an error or there is nothing to plot.
    """
    if "error" in shap_explanation:
        return None

    ranked = sorted(
        zip(shap_explanation["feature_names"], shap_explanation["shap_values"]),
        key=lambda pair: abs(pair[1]),
        reverse=True,
    )
    top_features = ranked[:top_n]
    if not top_features:
        return None

    names = [name for name, _ in top_features]
    values = [value for _, value in top_features]
    colors = ['green' if v > 0 else 'red' for v in values]

    fig = go.Figure(go.Bar(
        x=values,
        y=names,
        orientation='h',
        marker_color=colors,
        text=[f"{v:.4f}" for v in values],
        textposition="outside",
    ))

    fig.update_layout(
        title=f"SHAP Feature Contributions (Top {len(top_features)})",
        xaxis_title="SHAP Value (contribution to prediction)",
        yaxis_title="Features",
        height=600,
        # Keep the |SHAP| ranking with the strongest bar on top; ordering by
        # signed total would push strong negative contributions to the bottom.
        yaxis={'autorange': 'reversed'},
        showlegend=False,
    )

    # Dashed zero line separates positive from negative contributions.
    fig.add_vline(x=0, line_dash="dash", line_color="black", opacity=0.5)

    return fig
|
|
|
|
|
def create_feature_importance_plot(feature_importance: Dict[str, float], top_n: int = 20):
    """Plot the ``top_n`` features with the largest |importance| as horizontal bars.

    Args:
        feature_importance: Mapping of feature name to importance value.
        top_n: Maximum number of features to show.

    Returns:
        A plotly ``Figure``, or ``None`` for an empty mapping.
    """
    if not feature_importance:
        return None

    top_features = sorted(
        feature_importance.items(), key=lambda item: abs(item[1]), reverse=True
    )[:top_n]
    feature_names = [name for name, _ in top_features]
    importance_values = [value for _, value in top_features]

    # Set3 is a short qualitative palette (12 colours) while up to ``top_n``
    # bars are drawn; cycle it so every bar gets an explicit colour.
    palette = px.colors.qualitative.Set3
    bar_colors = [palette[i % len(palette)] for i in range(len(top_features))]

    fig = go.Figure(go.Bar(
        x=importance_values,
        y=feature_names,
        orientation='h',
        marker_color=bar_colors,
    ))

    fig.update_layout(
        title=f"Top {len(top_features)} Most Important Features",
        xaxis_title="Feature Importance",
        yaxis_title="Features",
        height=600,
        yaxis={'categoryorder': 'total ascending'},
    )

    return fig
|
|
|
|
|
def create_prediction_gauge(probability: np.ndarray, prediction: int):
    """Return a 300px-high plotly gauge of the highest class probability, in percent.

    The bar is light green when ``prediction == 0`` and light coral otherwise;
    grey/yellow/green bands mark 0-50, 50-80 and 80-100, with a red threshold
    marker at 90.
    """
    confidence_pct = max(probability) * 100
    bar_color = "lightgreen" if prediction == 0 else "lightcoral"

    confidence_bands = [
        {'range': [0, 50], 'color': "lightgray"},
        {'range': [50, 80], 'color': "yellow"},
        {'range': [80, 100], 'color': "lightgreen"},
    ]
    threshold_marker = {
        'line': {'color': "red", 'width': 4},
        'thickness': 0.75,
        'value': 90,
    }
    gauge_spec = {
        'axis': {'range': [None, 100]},
        'bar': {'color': bar_color},
        'steps': confidence_bands,
        'threshold': threshold_marker,
    }

    indicator = go.Indicator(
        mode="gauge+number+delta",
        value=confidence_pct,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': "Prediction Confidence (%)"},
        gauge=gauge_spec,
    )

    fig = go.Figure(indicator)
    fig.update_layout(height=300)
    return fig
|
|
|
|
|
# Built-in snippets offered by the "Load Example" selector.
EXAMPLE_SNIPPETS = {
    "Python Function": '''def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n-1) + fibonacci(n-2)''',

    "Simple Loop": '''for i in range(10):
    print(f"Number: {i}")
    if i % 2 == 0:
        print("Even")''',

    "Class Definition": '''class Calculator:
    def __init__(self):
        self.history = []

    def add(self, a, b):
        result = a + b
        self.history.append(f"{a} + {b} = {result}")
        return result''',
}

# Rows shown in the feature-value table unless "Show all features" is ticked.
DEFAULT_FEATURE_ROWS = 20


def _render_sidebar() -> None:
    """Render the sidebar: a 'Load Model' button until loaded, then model metrics."""
    with st.sidebar:
        st.header("Model Information")

        if not st.session_state.model_loaded:
            if st.button("Load Model"):
                with st.spinner("Loading model..."):
                    if st.session_state.model_loader.load_model():
                        st.session_state.model_loaded = True
                        st.success("Model loaded successfully!")
                    else:
                        st.error("Failed to load model. Please ensure the model files exist in the 'results' directory.")

        if st.session_state.model_loaded:
            for key, value in st.session_state.model_loader.get_model_info().items():
                st.metric(key, value)


def _show_syntax_check(code: str) -> None:
    """Report whether *code* parses as Python; analysis is allowed either way."""
    import ast

    try:
        ast.parse(code)
    except SyntaxError as e:
        st.warning(f"⚠️ Syntax error detected: {e}")
        st.info("Note: The model can still analyze syntactically incorrect code, but results may be less reliable.")
    except Exception:
        st.info("Code validation skipped (not standard Python)")
    else:
        st.success("✅ Valid Python syntax")


def _render_code_input() -> str:
    """Render the example picker and code editor, and return the entered code."""
    example_col, _ = st.columns([1, 3])
    with example_col:
        selected_example = st.selectbox("Load Example:", [""] + list(EXAMPLE_SNIPPETS.keys()))

    if selected_example:
        code_input = st.text_area("Code:", EXAMPLE_SNIPPETS[selected_example], height=200, key="code_input")
    else:
        code_input = st.text_area("Code:", height=200, placeholder="Enter your code here...", key="code_input")

    if code_input.strip():
        _show_syntax_check(code_input)
    return code_input


def _render_global_importance(feature_importance: Dict[str, float]) -> None:
    """Render the global-importance chart and a ranked top-10 list."""
    st.subheader("Model's Overall Feature Importance")
    fig = create_feature_importance_plot(feature_importance, top_n=20)
    if fig:
        st.plotly_chart(fig, use_container_width=True)

    st.subheader("Top Contributing Features:")
    sorted_features = sorted(feature_importance.items(), key=lambda x: abs(x[1]), reverse=True)

    list_col, note_col = st.columns(2)
    with list_col:
        st.write("**Most Important:**")
        for rank, (feature, importance) in enumerate(sorted_features[:10], 1):
            st.write(f"{rank}. **{feature}**: {importance:.4f}")
    with note_col:
        st.write("**Feature Description:**")
        st.info("These are the features the model finds most important globally across all predictions.")


def _render_feature_values(result: Dict[str, Any], show_all_features: bool) -> None:
    """Render this snippet's feature values, sorted by global importance.

    Only the top ``DEFAULT_FEATURE_ROWS`` rows are shown unless
    *show_all_features* is set.
    """
    st.subheader("Current Code's Feature Values")

    feature_names = st.session_state.model_loader.metadata.get("feature_names", [])
    feature_values = result["features"]

    if len(feature_names) != len(feature_values):
        st.info(
            f"Feature names ({len(feature_names)}) do not match the extracted "
            f"feature vector ({len(feature_values)}); table unavailable."
        )
        return

    feature_df = pd.DataFrame({
        "Feature": feature_names,
        "Value": feature_values,
        "Global_Importance": [result["feature_importance"].get(name, 0) for name in feature_names],
    }).sort_values("Global_Importance", ascending=False)

    if not show_all_features and len(feature_df) > DEFAULT_FEATURE_ROWS:
        st.caption(
            f"Showing the top {DEFAULT_FEATURE_ROWS} of {len(feature_df)} features; "
            "enable \"Show all features in results\" to see them all."
        )
        feature_df = feature_df.head(DEFAULT_FEATURE_ROWS)

    st.dataframe(feature_df, height=400)


def _render_shap(analyzer: "CodeAnalyzer", code_input: str) -> None:
    """Compute and render the per-prediction SHAP explanation."""
    st.subheader("SHAP Analysis: Why This Prediction?")

    with st.spinner("Computing SHAP explanations..."):
        shap_result = analyzer.get_shap_explanation(code_input)

    if "error" in shap_result:
        st.warning(f"SHAP analysis failed: {shap_result['error']}")
        st.info("Falling back to global feature importance above.")
        return

    shap_fig = create_shap_waterfall_plot(shap_result, top_n=15)
    if shap_fig:
        st.plotly_chart(shap_fig, use_container_width=True)
        st.info("""
**How to read SHAP values:**
- Green bars push the prediction toward "Machine-generated"
- Red bars push the prediction toward "Human-written"
- Longer bars = stronger influence on this specific prediction
- Values show how much each feature contributed to moving the prediction from the baseline
""")


def _render_results(
    analyzer: "CodeAnalyzer",
    result: Dict[str, Any],
    code_input: str,
    *,
    show_all_features: bool,
    use_shap: bool,
) -> None:
    """Render the prediction summary followed by the explanation tabs."""
    st.header("π Analysis Results")

    pred_col, prob_col, gauge_col = st.columns(3)
    with pred_col:
        st.metric(
            "Prediction",
            result["label"],
            delta=f"{max(result['probability']):.1%} confidence",
        )
    with prob_col:
        # Assumes class 0 = human-written, class 1 = machine-generated.
        human_prob = result["probability"][0]
        machine_prob = result["probability"][1]
        st.metric("Human-written", f"{human_prob:.1%}")
        st.metric("Machine-generated", f"{machine_prob:.1%}")
    with gauge_col:
        gauge_fig = create_prediction_gauge(result["probability"], result["prediction"])
        st.plotly_chart(gauge_fig, use_container_width=True)

    if not result["feature_importance"]:
        st.warning("Feature importance not available for this model.")
        return

    st.header("π Feature Importance & Explanations")

    # The SHAP tab honours the user's "Enable SHAP explanations" choice.
    show_shap = SHAP_AVAILABLE and use_shap
    tabs = ["Global Importance", "Feature Values"]
    if show_shap:
        tabs.append("SHAP Explanations")
    tab_objects = st.tabs(tabs)

    with tab_objects[0]:
        _render_global_importance(result["feature_importance"])
    with tab_objects[1]:
        _render_feature_values(result, show_all_features)
    if show_shap:
        with tab_objects[2]:
            _render_shap(analyzer, code_input)


def _render_about() -> None:
    """Render the static 'About This Tool' footer."""
    st.markdown("---")
    st.markdown("### About This Tool")
    purpose_col, limits_col = st.columns(2)

    with purpose_col:
        st.info("""
**Purpose**: This tool helps detect whether code was written by humans or generated by AI.

**Method**: Uses static code analysis with machine learning, focusing on patterns in:
- Code structure and complexity
- Naming conventions and style
- Syntactic patterns and AST features
- Error handling and control flow
""")

    with limits_col:
        st.warning(
            """
**Limitations**:
- Works with Python code only
- Accuracy depends on code length and complexity
- Results are probabilistic, not definitive
"""
        )


def main():
    """Streamlit entry point: load the model, take code input and show the analysis."""
    st.title("π€ AI Code Detection Tool")
    st.markdown("### Detect whether code is human-written or machine-generated with explainable AI")

    # The loader lives in session state so it is created once per session.
    if 'model_loader' not in st.session_state:
        st.session_state.model_loader = ModelLoader()
        st.session_state.model_loaded = False

    _render_sidebar()

    if not st.session_state.model_loaded:
        st.warning("⚠️ Please load the model first using the sidebar.")
        st.info("Make sure you have trained a model using the main script with `save_model: true` in the config.")
        return

    analyzer = CodeAnalyzer(st.session_state.model_loader)

    st.header("π Enter Code to Analyze")
    code_input = _render_code_input()

    with st.expander("Analysis Options"):
        show_all_features = st.checkbox("Show all features in results", value=False)
        use_shap = st.checkbox("Enable SHAP explanations", value=SHAP_AVAILABLE, disabled=not SHAP_AVAILABLE)
        if not SHAP_AVAILABLE:
            st.info("Install SHAP (`pip install shap`) for advanced explanations")

    if st.button("π Analyze Code", type="primary"):
        if not code_input.strip():
            st.warning("Please enter some code to analyze.")
            return

        with st.spinner("Analyzing code..."):
            result = analyzer.predict(code_input)

        if "error" in result:
            st.error(result["error"])
            return

        _render_results(
            analyzer,
            result,
            code_input,
            show_all_features=show_all_features,
            use_shap=use_shap,
        )

    _render_about()
|
|
|
|
|
# Start the app when this file is executed as the main script.
if __name__ == "__main__":
    main()
|
|
|