Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ai/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ dependencies {

implementation("com.github.haifengl:smile-core:2.6.0")
implementation("com.github.haifengl:smile-plot:2.6.0")
implementation("com.github.haifengl:smile-data:2.6.0")
implementation("com.github.haifengl:smile-anomaly:2.6.0")

testImplementation("org.springframework.boot:spring-boot-starter-test")
testImplementation("org.jetbrains.kotlin:kotlin-test-junit5")
Expand Down
4 changes: 3 additions & 1 deletion ai/src/main/kotlin/lab/Application.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package lab
import org.springframework.boot.autoconfigure.SpringBootApplication
import org.springframework.boot.runApplication

@SpringBootApplication
@SpringBootApplication(
scanBasePackages = ["lab.api", "lab.`ai-model`"]
)
class Application

fun main(args: Array<String>) {
Expand Down
16 changes: 16 additions & 0 deletions ai/src/main/kotlin/lab/ai-model/NormalizationUtil.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package lab.`ai-model`

object NormalizationUtil {
fun normalize(features: Array<DoubleArray>): Array<DoubleArray> {
val numFeatures = features.first().size
val minVals = DoubleArray(numFeatures) { idx -> features.minOf { it[idx] } }
val maxVals = DoubleArray(numFeatures) { idx -> features.maxOf { it[idx] } }

return features.map { f ->
DoubleArray(numFeatures) { i ->
if (maxVals[i] == minVals[i]) 0.0
else (f[i] - minVals[i]) / (maxVals[i] - minVals[i])
}
}.toTypedArray()
}
}
74 changes: 74 additions & 0 deletions ai/src/main/kotlin/lab/ai-model/gc/GcAnomalyDetector.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package lab.`ai-model`.gc

import lab.`ai-model`.NormalizationUtil
import org.slf4j.LoggerFactory
import org.springframework.stereotype.Component
import smile.clustering.KMeans
import smile.math.distance.EuclideanDistance
import java.io.File
import java.io.ObjectOutputStream

@Component
class GcAnomalyDetector {

private val extractor: GcFeatureExtractor by lazy { GcFeatureExtractor }
private val normalizationUtil: NormalizationUtil by lazy { NormalizationUtil }

private val projectRootDir: String = System.getProperty("user.dir")
private val modelDir = File("$projectRootDir/ai-models/gc-anomaly").apply { mkdirs() }

private lateinit var model: KMeans

private val log = LoggerFactory.getLogger(GcAnomalyDetector::class.java)

fun train(k: Int = 3) {
log.info("Start GcAnomalyDetector(KMeans) training...")

val dataList = getDataList()
if (dataList.isEmpty()) {
log.warn("No data available for training.")
return
}

val features = dataList.map { extractor.extract(it) }.toTypedArray()
val normalized = normalizationUtil.normalize(features)

// KMeans 학습
model = KMeans.fit(normalized, k)
log.info("✅ KMeans training completed with $k clusters.")
saveModel()
}

fun predict(data: GcTrainData): Double {
val dist = EuclideanDistance()
val features = extractor.extract(data)
val cluster = model.predict(features)
val centroid = model.centroids[cluster]
val distance = dist.d(features, centroid)
log.info("🔎 Distance to cluster center: %.4f".format(distance))
return distance
}

private fun saveModel() {
val file = File(modelDir, "gc_anomaly_kmeans.model")
ObjectOutputStream(file.outputStream().buffered()).use { oos ->
oos.writeObject(model)
}
log.info("💾 Saved KMeans anomaly model → ${file.absolutePath}")
}

private fun getDataList(isTestSet: Boolean = true): List<GcTrainData> {
return if(isTestSet) {
listOf(
GcTrainData(120, 500, 40, 1.2, 400_000, "G1", label = 1),
GcTrainData(200, 800, 350, 3.8, 1_200_000, "G1", label = 1),
GcTrainData(85, 260, 12, 0.9, 250_000, "Parallel", label = 1),
GcTrainData(600, 1500, 900, 6.8, 2_200_000, "G1", label = 1),
GcTrainData(95, 340, 18, 1.4, 370_000, "Serial", label = 1)
)
} else {
// TODO heesung feature
listOf()
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package lab.`ai-model`.gc

import lab.`ai-model`.NormalizationUtil
import org.slf4j.LoggerFactory
import org.springframework.stereotype.Component
import smile.classification.LogisticRegression
Expand All @@ -12,14 +13,17 @@ import java.io.ObjectOutputStream
import java.util.*

@Component
class GcTrainer {
class GcLogisticRegressionTrainer {
private val extractor: GcFeatureExtractor by lazy { GcFeatureExtractor }
private lateinit var model: LogisticRegression
private val log = LoggerFactory.getLogger(GcTrainer::class.java)
private val normalizationUtil: NormalizationUtil by lazy { NormalizationUtil }

private val projectRootDir: String = System.getProperty("user.dir")
private val modelDir = File("$projectRootDir/ai-models/gc-model").apply { mkdirs() }

private lateinit var model: LogisticRegression

private val log = LoggerFactory.getLogger(GcLogisticRegressionTrainer::class.java)

fun train() {
log.info("Start GcTrainer training...")
val dataList = getDataList()
Expand All @@ -37,7 +41,7 @@ class GcTrainer {
log.info("Sample training data: ${dataList.take(3)}")

val features = dataList.map { extractor.extract(it) }.toTypedArray()
val normalizedFeatures = normalize(features)
val normalizedFeatures = normalizationUtil.normalize(features)
val labels = dataList.map { it.label }.toIntArray()

val df = DataFrame.of(
Expand Down Expand Up @@ -68,19 +72,6 @@ class GcTrainer {
saveModel("test")
}

private fun normalize(features: Array<DoubleArray>): Array<DoubleArray> {
val numFeatures = features.first().size
val minVals = DoubleArray(numFeatures) { idx -> features.minOf { it[idx] } }
val maxVals = DoubleArray(numFeatures) { idx -> features.maxOf { it[idx] } }

return features.map { f ->
DoubleArray(numFeatures) { i ->
if (maxVals[i] == minVals[i]) 0.0
else (f[i] - minVals[i]) / (maxVals[i] - minVals[i])
}
}.toTypedArray()
}

private fun saveModel(key: String) {
val m = model ?: run {
log.error("Model not trained. Cannot save [$key].")
Expand All @@ -96,14 +87,18 @@ class GcTrainer {
log.info("💾 Saved model [$key] → ${file.absolutePath}")
}

private fun getDataList(): List<GcTrainData> {
// TODO - khope heesung이 만들어준 data get에서 가져와쓰는걸로 수정
return listOf(
GcTrainData(100, 400, 30, 1.2, 300_000, "G1", label = 1),
GcTrainData(150, 700, 300, 3.8, 1_000_000, "G1", label = 0),
GcTrainData(80, 250, 15, 0.8, 200_000, "Parallel", label = 1),
GcTrainData(400, 1200, 700, 6.2, 2_000_000, "G1", label = 0),
GcTrainData(90, 320, 20, 1.5, 350_000, "Serial", label = 1)
)
private fun getDataList(isTestSet: Boolean = true): List<GcTrainData> {
return if(isTestSet) {
listOf(
GcTrainData(120, 500, 40, 1.2, 400_000, "G1", label = 1),
GcTrainData(200, 800, 350, 3.8, 1_200_000, "G1", label = 1),
GcTrainData(85, 260, 12, 0.9, 250_000, "Parallel", label = 1),
GcTrainData(600, 1500, 900, 6.8, 2_200_000, "G1", label = 1),
GcTrainData(95, 340, 18, 1.4, 370_000, "Serial", label = 1)
)
} else {
// TODO heesung feature
listOf()
}
}
}
15 changes: 11 additions & 4 deletions ai/src/main/kotlin/lab/api/ApiController.kt
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
package lab.api

import lab.`ai-model`.gc.GcTrainer
import lab.`ai-model`.gc.GcAnomalyDetector
import lab.`ai-model`.gc.GcLogisticRegressionTrainer
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RequestParam
import org.springframework.web.bind.annotation.RestController

@RestController
class ApiController(
private val gcTrainer: GcTrainer
private val gcLogisticRegressionTrainer: GcLogisticRegressionTrainer,
private val gcAnomalyDetector: GcAnomalyDetector,
) {

@GetMapping("/api/train")
fun train() {
gcTrainer.train()
fun train(@RequestParam trainModelName: String?) {
when(trainModelName) {
"gc_logistic_regression" -> gcLogisticRegressionTrainer.train()
"gc_anomaly_detector" -> gcAnomalyDetector.train()
else -> throw IllegalArgumentException("Unknown model name: $trainModelName")
}
}
}