using GBDTs
using MultivariateTimeSeries
using Random

# Load the labeled dataset stored under the GBDTs package's `data` directory.
X, y = read_data_labeled(
    joinpath(dirname(pathof(GBDTs)), "..", "data", "auslan_youtube8")
);

# Grammar of candidate boolean expressions. Tree induction below starts from
# the symbol `:b`. Rule order is kept as-is because the search is seeded.
grammar = @grammar begin
    b = G(bvec) | F(bvec) | G(implies(bvec, bvec))
    bvec = and(bvec, bvec)
    bvec = or(bvec, bvec)
    bvec = not(bvec)
    bvec = lt(rvec, rvec)
    bvec = lte(rvec, rvec)
    bvec = gt(rvec, rvec)
    bvec = gte(rvec, rvec)
    bvec = f_lt(x, xid, v, vid)
    bvec = f_lte(x, xid, v, vid)
    bvec = f_gt(x, xid, v, vid)
    bvec = f_gte(x, xid, v, vid)
    rvec = x[!, xid]
    xid = |([
        :x_1, :y_1, :z_1, :roll_1, :pitch_1, :yaw_1,
        :thumbbend_1, :forebend_1, :middlebend_1, :ringbend_1, :littlebend_1,
        :x_2, :y_2, :z_2, :roll_2, :pitch_2, :yaw_2,
        :thumbbend_2, :forebend_2, :middlebend_2, :ringbend_2, :littlebend_2,
    ])
    vid = |(1:10)
end

# Temporal operators over a boolean vector.
# globally: true only when every element of `v` is true
G(v) = all(v)
# eventually: true when at least one element of `v` is true
F(v) = any(v)

# Feature-vs-constant comparisons: column `xid` of `x` is compared elementwise
# against the `vid`-th entry of the threshold table `v[xid]`.

# feature is less than a constant
function f_lt(x, xid, v, vid)
    column = x[!, xid]
    threshold = v[xid][vid]
    return lt(column, threshold)
end

# feature is less than or equal to a constant
function f_lte(x, xid, v, vid)
    column = x[!, xid]
    threshold = v[xid][vid]
    return lte(column, threshold)
end

# feature is greater than a constant
function f_gt(x, xid, v, vid)
    column = x[!, xid]
    threshold = v[xid][vid]
    return gt(column, threshold)
end

# feature is greater than or equal to a constant
function f_gte(x, xid, v, vid)
    column = x[!, xid]
    threshold = v[xid][vid]
    return gte(column, threshold)
end

# Workarounds for slow dot operators: each helper allocates its output buffer
# first and then broadcasts into it in place.

# implies: elementwise `v1 => v2`, i.e. `v2 | !v1`
function implies(v1, v2)
    out = similar(v1)
    out .= v2 .| .!v1
    return out
end

# not: elementwise negation
function not(v)
    out = similar(v)
    out .= .!v
    return out
end

# and: elementwise conjunction
function and(v1, v2)
    out = similar(v1)
    out .= v1 .& v2
    return out
end

# or: elementwise disjunction
function or(v1, v2)
    out = similar(v1)
    out .= v1 .| v2
    return out
end

# less than
function lt(x1, x2)
    out = Vector{Bool}(undef, length(x1))
    out .= x1 .< x2
    return out
end

# less than or equal to
function lte(x1, x2)
    out = Vector{Bool}(undef, length(x1))
    out .= x1 .≤ x2
    return out
end

# greater than
function gt(x1, x2)
    out = Vector{Bool}(undef, length(x1))
    out .= x1 .> x2
    return out
end

# greater than or equal to
function gte(x1, x2)
    out = Vector{Bool}(undef, length(x1))
    out .= x1 .≥ x2
    return out
end

# Threshold table: for each feature name, 10 evenly spaced values from that
# feature's minimum to its maximum (the grammar's `vid` picks one of the 10).
# NOTE(review): assumes `minimum(X)` / `maximum(X)` return one value per
# feature, in the same order as `names(X)` — confirm.
const v = Dict{Symbol,Vector{Float64}}()
mins, maxes = minimum(X), maximum(X)
for (col, feature) in enumerate(Symbol.(names(X)))
    lo = mins[col]
    hi = maxes[col]
    v[feature] = collect(range(lo; stop=hi, length=10))
end

# Search parameters for expression optimization.
# NOTE(review): the arguments are presumably (number of samples, max
# expression depth) — confirm against the GBDTs documentation.
p = MonteCarlo(2000, 5)

# Seed the global RNG immediately before induction so the run is repeatable.
Random.seed!(1)
# NOTE(review): the trailing `6` is presumably the maximum tree depth — confirm.
model = induce_tree(grammar, :b, p, X, y, 6);
# Render the induced tree. Edge labels are suppressed for clarity
# (left branch is true, right branch is false).
display(model; edgelabels=false)
show(model)

# Classify every record and measure the fraction predicted correctly.
ind = collect(1:length(X))
y_pred = classify(model, X, ind)
n_correct = count(y_pred .== y[ind])
accuracy = n_correct / length(ind)

# Members of node 3, shown via the adjoint (same as postfix `'`).
mvec = node_members(model, X, ind)
adjoint(mvec[3])