001/*
002 * SPDX-FileCopyrightText: none
003 * SPDX-License-Identifier: CC0-1.0
004 */
005
006package gov.nist.secauto.oscal.lib.model.util;
007
008import gov.nist.secauto.metaschema.core.metapath.DynamicContext;
009import gov.nist.secauto.metaschema.core.metapath.ISequence;
010import gov.nist.secauto.metaschema.core.metapath.MetapathExpression;
011import gov.nist.secauto.metaschema.core.metapath.StaticContext;
012import gov.nist.secauto.metaschema.core.metapath.item.node.AbstractRecursionPreventingNodeItemVisitor;
013import gov.nist.secauto.metaschema.core.metapath.item.node.IAssemblyInstanceGroupedNodeItem;
014import gov.nist.secauto.metaschema.core.metapath.item.node.IAssemblyNodeItem;
015import gov.nist.secauto.metaschema.core.metapath.item.node.IDefinitionNodeItem;
016import gov.nist.secauto.metaschema.core.metapath.item.node.IFieldNodeItem;
017import gov.nist.secauto.metaschema.core.metapath.item.node.IFlagNodeItem;
018import gov.nist.secauto.metaschema.core.metapath.item.node.IModuleNodeItem;
019import gov.nist.secauto.metaschema.core.metapath.item.node.INodeItemFactory;
020import gov.nist.secauto.metaschema.core.model.IModule;
021import gov.nist.secauto.metaschema.core.model.constraint.IAllowedValuesConstraint;
022
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
028
029import edu.umd.cs.findbugs.annotations.NonNull;
030
031public class AllowedValueCollectingNodeItemVisitor
032    extends AbstractRecursionPreventingNodeItemVisitor<DynamicContext, Void> {
033
034  private final Map<IDefinitionNodeItem<?, ?>, NodeItemRecord> nodeItemAnalysis = new LinkedHashMap<>();
035
036  public Collection<NodeItemRecord> getAllowedValueLocations() {
037    return nodeItemAnalysis.values();
038  }
039
040  public void visit(@NonNull IModule module) {
041    DynamicContext context = new DynamicContext(
042        StaticContext.builder()
043            .defaultModelNamespace(module.getXmlNamespace())
044            .build());
045    context.disablePredicateEvaluation();
046
047    visit(INodeItemFactory.instance().newModuleNodeItem(module), context);
048  }
049
050  public void visit(@NonNull IModuleNodeItem module, @NonNull DynamicContext context) {
051
052    visitMetaschema(module, context);
053  }
054
055  private void handleAllowedValuesAtLocation(@NonNull IDefinitionNodeItem<?, ?> itemLocation, DynamicContext context) {
056    itemLocation.getDefinition().getAllowedValuesConstraints().stream()
057        .forEachOrdered(allowedValues -> {
058          String metapath = allowedValues.getTarget();
059
060          MetapathExpression path = MetapathExpression.compile(metapath, context.getStaticContext());
061          ISequence<?> result = path.evaluate(itemLocation, context);
062          result.stream().forEachOrdered(target -> {
063            handleAllowedValues(allowedValues, itemLocation, (IDefinitionNodeItem<?, ?>) target);
064          });
065        });
066  }
067
068  private void handleAllowedValues(
069      @NonNull IAllowedValuesConstraint allowedValues,
070      @NonNull IDefinitionNodeItem<?, ?> location,
071      @NonNull IDefinitionNodeItem<?, ?> target) {
072    NodeItemRecord itemRecord = nodeItemAnalysis.get(target);
073    if (itemRecord == null) {
074      itemRecord = new NodeItemRecord(target);
075      nodeItemAnalysis.put(target, itemRecord);
076    }
077
078    AllowedValuesRecord allowedValuesRecord = new AllowedValuesRecord(allowedValues, location, target);
079    itemRecord.addAllowedValues(allowedValuesRecord);
080  }
081
082  @Override
083  public Void visitFlag(IFlagNodeItem item, DynamicContext context) {
084    handleAllowedValuesAtLocation(item, context);
085    return super.visitFlag(item, context);
086  }
087
088  @Override
089  public Void visitField(IFieldNodeItem item, DynamicContext context) {
090    handleAllowedValuesAtLocation(item, context);
091    return super.visitField(item, context);
092  }
093
094  @Override
095  public Void visitAssembly(IAssemblyNodeItem item, DynamicContext context) {
096    handleAllowedValuesAtLocation(item, context);
097
098    return super.visitAssembly(item, context);
099  }
100
101  @Override
102  public Void visitAssembly(IAssemblyInstanceGroupedNodeItem item, DynamicContext context) {
103    return visitAssembly((IAssemblyNodeItem) item, context);
104  }
105
106  @Override
107  protected Void defaultResult() {
108    return null;
109  }
110
111  public static final class NodeItemRecord {
112    @NonNull
113    private final IDefinitionNodeItem<?, ?> item;
114    @NonNull
115    private final List<AllowedValuesRecord> allowedValues = new LinkedList<>();
116
117    private NodeItemRecord(@NonNull IDefinitionNodeItem<?, ?> item) {
118      this.item = item;
119    }
120
121    @NonNull
122    public IDefinitionNodeItem<?, ?> getItem() {
123      return item;
124    }
125
126    @NonNull
127    public List<AllowedValuesRecord> getAllowedValues() {
128      return allowedValues;
129    }
130
131    public void addAllowedValues(@NonNull AllowedValuesRecord record) {
132      this.allowedValues.add(record);
133    }
134  }
135
136  public static final class AllowedValuesRecord {
137    @NonNull
138    private final IAllowedValuesConstraint allowedValues;
139    @NonNull
140    private final IDefinitionNodeItem<?, ?> location;
141    @NonNull
142    private final IDefinitionNodeItem<?, ?> target;
143
144    public AllowedValuesRecord(
145        @NonNull IAllowedValuesConstraint allowedValues,
146        @NonNull IDefinitionNodeItem<?, ?> location,
147        @NonNull IDefinitionNodeItem<?, ?> target) {
148      this.allowedValues = allowedValues;
149      this.location = location;
150      this.target = target;
151    }
152
153    @NonNull
154    public IAllowedValuesConstraint getAllowedValues() {
155      return allowedValues;
156    }
157
158    @NonNull
159    public IDefinitionNodeItem<?, ?> getLocation() {
160      return location;
161    }
162
163    @NonNull
164    public IDefinitionNodeItem<?, ?> getTarget() {
165      return target;
166    }
167  }
168}