Spaces:
Running
Running
File size: 5,704 Bytes
f01c9d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
// localStorage key used by saveModel(), loadModel() and clearModel().
const STORAGE_KEY = 'clipTaggerModel';

// Training target for each recognised feedback kind; any other kind is ignored.
// 'custom' feedback is treated the same as 'positive'.
const FEEDBACK_TARGETS = new Map([
  ['positive', 1.0],
  ['negative', 0.0],
  ['custom', 1.0],
]);

/**
 * Logistic sigmoid.
 * @param {number} z - Logit.
 * @returns {number} Value in [0, 1].
 */
const sigmoid = (z) => 1 / (1 + Math.exp(-z));

/**
 * Bias plus the dot product of `weights` and `features`, taken over the
 * overlapping prefix of the two vectors so that a length mismatch never
 * reads past the end of either one.
 * @param {ArrayLike<number>} weights
 * @param {number} bias
 * @param {ArrayLike<number>} features
 * @returns {number} The logit.
 */
function linearScore(weights, bias, features) {
  const n = Math.min(features.length, weights.length);
  let z = bias;
  for (let i = 0; i < n; i++) {
    z += weights[i] * features[i];
  }
  return z;
}

/**
 * Per-tag binary classifier trained online from user feedback.
 *
 * Each tag has an independent logistic-regression model: a weight vector
 * (created with length `featureDim`) plus a scalar bias. Every feedback event
 * applies one gradient step to that tag's model. The model can be persisted to
 * and restored from `localStorage`.
 */
class LocalClassifier {
  constructor() {
    this.weights = new Map(); // tag -> weight vector (array of numbers)
    this.biases = new Map(); // tag -> bias
    this.learningRate = 0.01;
    this.featureDim = 512; // CLAP embedding dimension
    this.isInitialized = false;
  }

  /**
   * Set the length used for weight vectors of tags created from now on.
   * Existing weight vectors are not resized.
   * @param {number} [featureDim=512]
   */
  initialize(featureDim = 512) {
    this.featureDim = featureDim;
    this.isInitialized = true;
  }

  /**
   * Apply one logistic-regression gradient step for `tag`.
   *
   * Unknown feedback kinds are ignored. A tag seen for the first time gets
   * small random initial weights in [-0.005, 0.005) and a zero bias. Only the
   * first min(features.length, weights.length) weights are used and updated.
   *
   * @param {ArrayLike<number>} features - Feature vector for one clip.
   * @param {string} tag - Tag the feedback refers to.
   * @param {string} feedback - 'positive', 'negative' or 'custom'.
   */
  trainOnFeedback(features, tag, feedback) {
    if (!this.isInitialized) {
      this.initialize();
    }

    if (!FEEDBACK_TARGETS.has(feedback)) {
      return; // Skip unknown feedback
    }
    const target = FEEDBACK_TARGETS.get(feedback);

    if (!this.weights.has(tag)) {
      this.weights.set(
        tag,
        Array.from({ length: this.featureDim }, () => (Math.random() - 0.5) * 0.01)
      );
      this.biases.set(tag, 0);
    }

    const weights = this.weights.get(tag);
    const bias = this.biases.get(tag);
    const logit = linearScore(weights, bias, features);

    // A NaN/Infinity logit means some feature or stored parameter is not a
    // finite number; applying the update would spread NaN into every weight.
    if (!Number.isFinite(logit)) {
      return;
    }

    // For sigmoid + log-loss, the gradient w.r.t. the logit is (prediction - target).
    const error = sigmoid(logit) - target;
    const n = Math.min(features.length, weights.length);
    for (let i = 0; i < n; i++) {
      weights[i] -= this.learningRate * error * features[i];
    }
    this.biases.set(tag, bias - this.learningRate * error);
  }

  /**
   * Predicted probability that `tag` applies to `features`.
   * @param {ArrayLike<number>} features
   * @param {string} tag
   * @returns {number|null} Confidence in [0, 1], or null if no model exists for `tag`.
   */
  predict(features, tag) {
    if (!this.weights.has(tag)) {
      return null; // No training data for this tag
    }
    return sigmoid(linearScore(this.weights.get(tag), this.biases.get(tag), features));
  }

  /**
   * Predict every candidate tag that has a model, sorted by descending confidence.
   * @param {ArrayLike<number>} features
   * @param {Iterable<string>} candidateTags
   * @returns {{tag: string, confidence: number}[]}
   */
  predictAll(features, candidateTags) {
    const predictions = [];
    for (const tag of candidateTags) {
      const confidence = this.predict(features, tag);
      if (confidence !== null) {
        predictions.push({ tag, confidence });
      }
    }
    return predictions.sort((a, b) => b.confidence - a.confidence);
  }

  /**
   * Train on a batch of stored feedback records.
   *
   * Each record is expected to look like
   * `{ audioFeatures, correctedTags: [{ tag, feedback }, ...] }`. Records without
   * `audioFeatures` or without a `correctedTags` array, and null tag entries,
   * are skipped.
   * @param {Iterable<object>} feedbackData
   */
  retrainOnBatch(feedbackData) {
    for (const item of feedbackData) {
      if (!item || !item.audioFeatures || !Array.isArray(item.correctedTags)) {
        continue;
      }
      // Features are derived once per record and shared by all its tags.
      const features = this.extractSimpleFeatures(item.audioFeatures);
      for (const tagData of item.correctedTags) {
        if (tagData) {
          this.trainOnFeedback(features, tagData.tag, tagData.feedback);
        }
      }
    }
  }

  /**
   * Build a deterministic pseudo-feature vector of length `featureDim` from
   * audio metadata. In a real implementation this would use actual CLAP
   * embeddings.
   *
   * Positions 0-2 hold `duration / 60` (minutes, assuming duration is in
   * seconds), `sampleRate / 48000` and `numberOfChannels`. Missing or
   * non-numeric values become 0, so they never produce NaN features. The
   * remaining positions are pseudo-random values in [0, 0.1), seeded by a hash
   * of the whole metadata object.
   *
   * @param {object|null|undefined} audioFeatures
   * @returns {number[]} All zeros when `audioFeatures` is falsy.
   */
  extractSimpleFeatures(audioFeatures) {
    const features = new Array(this.featureDim).fill(0);
    if (!audioFeatures) {
      return features;
    }

    const toFiniteNumber = (value) => {
      const num = Number(value);
      return Number.isFinite(num) ? num : 0;
    };
    const head = [
      toFiniteNumber(audioFeatures.duration) / 60,
      toFiniteNumber(audioFeatures.sampleRate) / 48000,
      toFiniteNumber(audioFeatures.numberOfChannels),
    ];
    // Never write past featureDim, even if it is smaller than the header.
    const headLength = Math.min(head.length, this.featureDim);
    for (let i = 0; i < headLength; i++) {
      features[i] = head[i];
    }

    const seed = this.simpleHash(JSON.stringify(audioFeatures));
    for (let i = head.length; i < this.featureDim; i++) {
      features[i] = this.seededRandom(seed + i) * 0.1;
    }
    return features;
  }

  /**
   * Non-cryptographic 32-bit string hash (hash * 31 + charCode over UTF-16
   * code units), returned as a non-negative number.
   * @param {string} str
   * @returns {number}
   */
  simpleHash(str) {
    let hash = 0;
    for (let i = 0; i < str.length; i++) {
      hash = ((hash << 5) - hash + str.charCodeAt(i)) | 0; // Keep it a 32-bit integer
    }
    return Math.abs(hash);
  }

  /**
   * Deterministic value in [0, 1) derived from `sin(seed)`.
   * This is cheap and repeatable, not statistically strong.
   * @param {number} seed
   * @returns {number}
   */
  seededRandom(seed) {
    const x = Math.sin(seed) * 10000;
    return x - Math.floor(x);
  }

  /**
   * Persist weights, biases, featureDim and learningRate to localStorage.
   * @returns {boolean} true on success; false (after logging) if storage threw,
   *   e.g. because the quota was exceeded or storage is unavailable.
   */
  saveModel() {
    const modelData = {
      weights: Object.fromEntries(this.weights),
      biases: Object.fromEntries(this.biases),
      featureDim: this.featureDim,
      learningRate: this.learningRate,
    };
    try {
      localStorage.setItem(STORAGE_KEY, JSON.stringify(modelData));
      return true;
    } catch (error) {
      console.error('Error saving model:', error);
      return false;
    }
  }

  /**
   * Restore a model previously written by saveModel().
   *
   * The load is all-or-nothing. If the stored payload cannot be read, parsed or
   * validated, the current in-memory model is left untouched. Tags that have
   * weights but no stored bias get a bias of 0.
   *
   * @returns {boolean} true if a model was loaded; false if nothing was stored
   *   or loading failed (the error is logged).
   */
  loadModel() {
    try {
      const saved = localStorage.getItem(STORAGE_KEY);
      if (!saved) {
        return false;
      }
      const modelData = JSON.parse(saved);
      const weights = new Map(Object.entries(modelData.weights));
      const biases = new Map(Object.entries(modelData.biases));
      for (const [tag, vector] of weights) {
        if (!Array.isArray(vector)) {
          throw new TypeError(`Invalid weight vector for tag "${tag}"`);
        }
        if (!biases.has(tag)) {
          biases.set(tag, 0);
        }
      }

      // Commit only after everything parsed and validated.
      this.weights = weights;
      this.biases = biases;
      this.featureDim = modelData.featureDim || 512;
      this.learningRate = modelData.learningRate || 0.01;
      this.isInitialized = true;
      return true;
    } catch (error) {
      console.error('Error loading model:', error);
      return false;
    }
  }

  /**
   * Summary of the current model.
   * @returns {{trainedTags: number, featureDim: number, learningRate: number, tags: string[]}}
   */
  getModelStats() {
    return {
      trainedTags: this.weights.size,
      featureDim: this.featureDim,
      learningRate: this.learningRate,
      tags: [...this.weights.keys()],
    };
  }

  /**
   * Drop all tag models from memory and remove the persisted copy.
   * A storage failure is logged; the in-memory model is still cleared.
   */
  clearModel() {
    this.weights.clear();
    this.biases.clear();
    try {
      localStorage.removeItem(STORAGE_KEY);
    } catch (error) {
      console.error('Error clearing saved model:', error);
    }
  }
}

export default LocalClassifier;